Compare commits

...

63 Commits
0.6.5 ... 0.6.6

Author SHA1 Message Date
93393e005e version to 0.6.6 (#4050) 2024-05-02 16:06:40 +08:00
4ea2755fce test: remove explicit env settings for CI pytests (#4041) 2024-05-02 00:49:39 +08:00
ecb51a83d4 chore(deps): bump semver from 5.7.1 to 5.7.2 in /web (#4022)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-30 18:47:05 +08:00
093b5c0e63 fix: typo of jinja2 (#4019) 2024-04-30 18:39:02 +08:00
bf42b0ae44 fix: lodash version has warning (#4020)
Co-authored-by: nite-knite <nkCoding@gmail.com>
2024-04-30 18:11:49 +08:00
342b4fd19d chore(deps): bump word-wrap from 1.2.3 to 1.2.5 in /web
Bumps [word-wrap](https://github.com/jonschlinkert/word-wrap) from 1.2.3 to 1.2.5.
- [Release notes](https://github.com/jonschlinkert/word-wrap/releases)
- [Commits](https://github.com/jonschlinkert/word-wrap/compare/1.2.3...1.2.5)

---
updated-dependencies:
- dependency-name: word-wrap
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-30 09:39:10 +00:00
cbdb861ee4 add glm-3-turbo max_tokens parameter setting (#4017)
Co-authored-by: 陈力坤 <likunchen@caixin.com>
2024-04-30 17:08:04 +08:00
da5a8b9a59 feat: support question classifier node output (#4000) 2024-04-30 17:07:29 +08:00
1e6e8b446d feat: support minimax abab6.5, abab6.5s (#4012) 2024-04-30 17:02:01 +08:00
c1fdaa6ae0 fix: prompt undefined caused match problem (#4010) 2024-04-30 16:31:36 +08:00
142814d451 chore: skip deprecated field_schema param in creating payload index on Qdrant (#3903) 2024-04-30 16:16:10 +08:00
704755d005 fix: submitCodeExecutionTask (#4006) 2024-04-30 16:01:03 +08:00
d1263700c0 Update the description and labels in Judge0ce tool (#3990)
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-30 14:58:29 +08:00
0704fe9695 fix(web): copy button visible at chat page normally (#4005)
Co-authored-by: rongjun.qiu <qiurj@hengtonggroup.com.cn>
2024-04-30 14:55:57 +08:00
1d3f1d88ef Enabled Notion integration setup in Docker Compose Deployment (#3919) 2024-04-30 14:48:39 +08:00
8b3edac091 fix: prompt editor insert quickly (#4004) 2024-04-30 14:25:21 +08:00
05cab85579 fix: workflow disable shortcuts when feature panel occurred (#4001) 2024-04-30 13:35:49 +08:00
b72fbe200d chore: add sandbox tag (#3997) 2024-04-30 12:35:19 +08:00
b1194da6a5 fix: ci (#3983) 2024-04-29 18:59:37 +08:00
338e4669e5 add storage factory (#3922) 2024-04-29 18:22:03 +08:00
c5e2659771 Feat/install process refinement (#3982) 2024-04-29 17:55:52 +08:00
1d432728ac add default value for QDRANT_GRPC_PORT (#3976) 2024-04-29 15:28:34 +08:00
2fd702a319 Fix: password check in page of install (#3978) 2024-04-29 15:27:45 +08:00
f26ad16af7 Add new tool: Firecrawl (#3819)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
2024-04-29 14:20:36 +08:00
8f2ae51fe5 feat: add support for request timeout settings in the HTTP request node. (#3854)
Co-authored-by: Yeuoly <admin@srmxy.cn>
2024-04-29 13:59:07 +08:00
2f84d00300 fix-nvidia-llama3 (#3973) 2024-04-29 13:41:15 +08:00
b82a2d97ef fix: db connections not being released during workflow execution (#3971) 2024-04-29 12:42:09 +08:00
3e9dbe3e0a add pgvecto_rs support and upgrade SQLAlchemy (#3833) 2024-04-29 11:58:17 +08:00
975b2fb79e delete duplicate check get_dataset (#3966)
Co-authored-by: baxiang <baxiang@lixiang.com>
2024-04-29 11:57:26 +08:00
fa509ce64e feat: rename var name sync to used jinja code (#3964) 2024-04-29 11:34:30 +08:00
99292edd46 chore: update @types/react (#3939) 2024-04-28 19:01:09 +08:00
3e992cb23c feat: code transform node editor support insert var by add slash or left brace (#3946)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-04-28 17:51:58 +08:00
e7b4d024ee optimize: code node has a bad error message (#3949) 2024-04-28 17:40:29 +08:00
ff67a6d338 feat: llm text stream support for workflow app (#3798)
Co-authored-by: JzoNg <jzongcode@gmail.com>
2024-04-28 17:37:00 +08:00
8e4989ed03 feat: workflow remove preview mode (#3941) 2024-04-28 17:09:56 +08:00
0940f01634 enhancement:support Qdrant gRPC mode (#3929) 2024-04-28 15:33:32 +08:00
9d1cb1bc92 improvement: Optimizing the experience of the app list page (#3885) 2024-04-28 13:52:45 +08:00
0ca4e30b19 feat: add start commands to devcontainer (#3902) 2024-04-28 12:30:56 +08:00
ba88f8a6f0 fix: code full screen in web app cause error (#3935) 2024-04-28 11:59:57 +08:00
aefe0cbf51 fix: api doc example error (#3925) 2024-04-28 10:18:07 +08:00
9ad489d133 feat: Add google storage support (#3887)
Co-authored-by: miendinh <miendinh@users.noreply.github.com>
2024-04-27 18:26:52 +08:00
661b30784e chore: skip warning messages when pytest auto-collecting the vdb test class by removing Test prefix (#3906) 2024-04-27 16:36:09 +08:00
43a5ba9415 feat: add support for Bedrock LLAMA3 (#3890) 2024-04-27 13:13:09 +08:00
08a65d74d5 fix: hydration warning (#3897) 2024-04-26 21:34:29 +08:00
cefe156811 feat: replicate supports default version. (#3884) 2024-04-26 21:16:22 +08:00
3b5b4d628b Add support for Traditional Chinese language (#3899)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-26 21:10:23 +08:00
8746e48df0 chore: integrate code-inspector-plugin (#3900) 2024-04-26 21:00:29 +08:00
0ec8b57825 add together ai model setting (#3895) 2024-04-26 20:43:17 +08:00
045827043d test: improve vector store tests (#3855) 2024-04-26 19:18:42 +08:00
4d66a86579 fix: fetch page name of notion wiki (#3847) 2024-04-26 18:04:37 +08:00
2a8881d0e8 fix: tool webscraper - too many redirects in case target url does not… (#3831)
Co-authored-by: miendinh <miendinh@users.noreply.github.com>
2024-04-26 17:58:46 +08:00
ffc60bb917 add the comment in entrypoint.sh (#3882) 2024-04-26 17:19:49 +08:00
2e454c770b fix: copy invite link for HTTPS has duplicate origin (#3877) 2024-04-26 15:19:30 +08:00
7d711135bc fix: full screen editor not follow panel width (#3876) 2024-04-26 14:23:13 +08:00
f62b2b5b45 optimize the knowledge failed documents query (#3870)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-04-26 11:47:23 +08:00
7919596a21 fix: UP031 style rule violation (#3866) 2024-04-26 11:24:08 +08:00
9b4898efeb fix: chat api doc not show title in english version (#3864) 2024-04-26 10:32:45 +08:00
45dd1683fd test: add tests covering all methods of vector store (#3849) 2024-04-25 22:27:30 +08:00
8bca908f15 refactor: config file (#3852) 2024-04-25 22:26:45 +08:00
9cbb8ddd7f fix: billing tenant account role. (#3850) 2024-04-25 21:55:08 +08:00
1be222af2e fix: using api can not execute relyt vector database (#3766)
Co-authored-by: jingsi <jingsi@leadincloud.com>
2024-04-25 19:46:20 +08:00
bf9fc8fef4 Reduce tool redundancy for [Judge0 CE] (#3837)
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-25 19:20:54 +08:00
86e7330fa2 test: refactor vdb tests by visitor design pattern (#3838) 2024-04-25 18:55:49 +08:00
236 changed files with 7828 additions and 3156 deletions

View File

@ -32,8 +32,8 @@
]
}
},
"postStartCommand": "cd api && pip install -r requirements.txt",
"postCreateCommand": "cd web && npm install"
"postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},

View File

@ -0,0 +1,10 @@
#!/bin/bash
cd web && npm install
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -0,0 +1,3 @@
#!/bin/bash
cd api && pip install -r requirements.txt

View File

@ -14,50 +14,10 @@ jobs:
- "3.10"
- "3.11"
env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
CHATGLM_API_BASE: http://a.abc.com:11451
XINFERENCE_SERVER_URL: http://a.abc.com:11451
XINFERENCE_GENERATION_MODEL_UID: generate
XINFERENCE_CHAT_MODEL_UID: chat
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
XINFERENCE_RERANK_MODEL_UID: rerank
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true
CODE_MAX_STRING_LENGTH: 80000
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Weaviate
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.middleware.yaml
services: weaviate
- name: Set up Qdrant
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.qdrant.yaml
services: qdrant
- name: Set up Milvus
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.milvus.yaml
services: |
etcd
minio
milvus-standalone
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
@ -82,5 +42,21 @@ jobs:
- name: Run Workflow
run: dev/pytest/pytest_workflow.sh
- name: Run Vector Stores
- name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
docker/docker-compose.qdrant.yaml
docker/docker-compose.milvus.yaml
docker/docker-compose.pgvecto-rs.yaml
services: |
weaviate
qdrant
etcd
minio
milvus-standalone
pgvecto-rs
- name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh

View File

@ -1,6 +1,3 @@
# Server Edition
EDITION=SELF_HOSTED
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
@ -57,12 +54,15 @@ ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, relyt
# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
VECTOR_STORE=weaviate
# Weaviate configuration
@ -75,6 +75,8 @@ WEAVIATE_BATCH_SIZE=100
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
QDRANT_CLIENT_TIMEOUT=20
QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
# Milvus configuration
MILVUS_HOST=127.0.0.1
@ -90,6 +92,13 @@ RELYT_USER=postgres
RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=postgres
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@ -123,25 +132,6 @@ NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret
# Hosted Model Credentials
HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_TRIAL_ENABLED=false
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
HOSTED_AZURE_OPENAI_API_BASE=
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_TRIAL_ENABLED=false
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
ETL_TYPE=dify
UNSTRUCTURED_API_URL=
@ -166,5 +156,10 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# HTTP Node configuration
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
# Log file path
LOG_FILE=
LOG_FILE=
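
The new PGVECTO_RS_* settings feed directly into the SQLAlchemy connection URL that the pgvecto-rs store builds (see the PGVectoRS class later in this diff). A minimal sketch of that mapping, loading the file with python-dotenv as api/config.py already does; the fallback values mirror the defaults above:

import os

import dotenv

dotenv.load_dotenv()

# Assemble the same postgresql+psycopg2 URL that PGVectoRS constructs.
url = "postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}".format(
    user=os.getenv("PGVECTO_RS_USER", "postgres"),
    password=os.getenv("PGVECTO_RS_PASSWORD", "difyai123456"),
    host=os.getenv("PGVECTO_RS_HOST", "localhost"),
    port=os.getenv("PGVECTO_RS_PORT", "5431"),
    database=os.getenv("PGVECTO_RS_DATABASE", "postgres"),
)
print(url)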

View File

@ -1,28 +1,28 @@
import os
import sys
from logging.handlers import RotatingFileHandler
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import json
import logging
import sys
import threading
import time
import warnings
from logging.handlers import RotatingFileHandler
from flask import Flask, Response, request
from flask_cors import CORS
from werkzeug.exceptions import Unauthorized
from commands import register_commands
from config import CloudEditionConfig, Config
from config import Config
# DO NOT REMOVE BELOW
from events import event_handlers
@ -75,16 +75,9 @@ config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
# ----------------------------
def create_app(test_config=None) -> Flask:
def create_app() -> Flask:
app = DifyApp(__name__)
if test_config:
app.config.from_object(test_config)
else:
if config_type == "CLOUD":
app.config.from_object(CloudEditionConfig())
else:
app.config.from_object(Config())
app.config.from_object(Config())
app.secret_key = app.config['SECRET_KEY']
@ -101,6 +94,7 @@ def create_app(test_config=None) -> Flask:
),
logging.StreamHandler(sys.stdout)
]
logging.basicConfig(
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),

View File

@ -5,6 +5,7 @@ import dotenv
dotenv.load_dotenv()
DEFAULTS = {
'EDITION': 'SELF_HOSTED',
'DB_USERNAME': 'postgres',
'DB_PASSWORD': '',
'DB_HOST': 'localhost',
@ -36,6 +37,8 @@ DEFAULTS = {
'WEAVIATE_GRPC_ENABLED': 'True',
'WEAVIATE_BATCH_SIZE': 100,
'QDRANT_CLIENT_TIMEOUT': 20,
'QDRANT_GRPC_ENABLED': 'False',
'QDRANT_GRPC_PORT': '6334',
'CELERY_BACKEND': 'database',
'LOG_LEVEL': 'INFO',
'LOG_FILE': '',
@ -104,9 +107,9 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.6.5"
self.CURRENT_VERSION = "0.6.6"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.EDITION = get_env('EDITION')
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
@ -212,6 +215,8 @@ class Config:
self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
# ------------------------
# Vector Store Configurations.
@ -223,6 +228,8 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
self.QDRANT_GRPC_ENABLED = get_env('QDRANT_GRPC_ENABLED')
self.QDRANT_GRPC_PORT = get_env('QDRANT_GRPC_PORT')
# milvus / zilliz setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
@ -245,6 +252,13 @@ class Config:
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
self.RELYT_DATABASE = get_env('RELYT_DATABASE')
# pgvecto rs settings
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
# ------------------------
# Mail Configurations.
# ------------------------
@ -260,7 +274,7 @@ class Config:
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
# ------------------------
# Workpace Configurations.
# Workspace Configurations.
# ------------------------
self.INVITE_EXPIRY_HOURS = int(get_env('INVITE_EXPIRY_HOURS'))
@ -299,6 +313,12 @@ class Config:
# ------------------------
# Platform Configurations.
# ------------------------
self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
@ -345,17 +365,3 @@ class Config:
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
class CloudEditionConfig(Config):
def __init__(self):
super().__init__()
self.EDITION = "CLOUD"
self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
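
With CloudEditionConfig gone, the edition is now selected purely by the EDITION environment variable, falling back to the SELF_HOSTED entry added to DEFAULTS above. A minimal sketch of that lookup, assuming get_env resolves against DEFAULTS the way the rest of config.py does:

import os

DEFAULTS = {'EDITION': 'SELF_HOSTED'}

def get_env(key):
    # Environment first, then the DEFAULTS fallback (assumed to match api/config.py).
    return os.environ.get(key, DEFAULTS.get(key))

assert get_env('EDITION') == 'SELF_HOSTED'  # no override: self-hosted
os.environ['EDITION'] = 'CLOUD'             # cloud deployments now just export this
assert get_env('EDITION') == 'CLOUD'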

View File

@ -1,10 +1,11 @@
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
language_timezone_mapping = {
'en-US': 'America/New_York',
'zh-Hans': 'Asia/Shanghai',
'zh-Hant': 'Asia/Taipei',
'pt-BR': 'America/Sao_Paulo',
'es-ES': 'Europe/Madrid',
'fr-FR': 'Europe/Paris',

View File

@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
if vector_type == 'milvus':
if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
return {
'retrieval_method': [
'semantic_search'
@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
if vector_type == 'milvus':
if vector_type == 'milvus' or vector_type == 'relyt':
return {
'retrieval_method': [
'semantic_search'
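
The chained equality checks work, but the same condition reads tighter as a membership test. An equivalent sketch, assuming the truncated else-branch returns the full method list for stores that support full-text search:

SEMANTIC_ONLY_VECTOR_TYPES = {'milvus', 'pgvecto_rs', 'relyt'}

def retrieval_methods(vector_type: str) -> list[str]:
    # Stores without full-text indexes only support semantic search.
    if vector_type in SEMANTIC_ONLY_VECTOR_TYPES:
        return ['semantic_search']
    return ['semantic_search', 'full_text_search', 'hybrid_search']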

View File

@ -394,9 +394,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
def get(self, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
documents = self.get_batch_documents(dataset_id, batch)
response = {
"tokens": 0,

View File

@ -28,9 +28,9 @@ from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageEndStreamResponse,
StreamGenerateRoute,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -343,7 +343,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
@ -366,7 +366,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = StreamGenerateRoute(
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
@ -430,15 +430,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(token)
if should_direct_answer:
continue
self._task_state.answer += token
yield self._message_to_stream_response(token, self._message.id)
time.sleep(0.01)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
@ -463,10 +462,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._task_state.answer += token
yield self._message_to_stream_response(token, self._message.id)
time.sleep(0.01)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
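
The old loop streamed one character at a time with a 10 ms sleep between tokens; the new code forwards each route chunk whole. A toy sketch of the behavioral difference:

import time

def stream_per_char(text):
    # Old behavior: one response per character, throttled by sleep.
    for ch in text:
        time.sleep(0.01)
        yield ch

def stream_whole_chunk(text):
    # New behavior: one response per route chunk, no artificial delay.
    yield text

assert ''.join(stream_per_char('hi')) == ''.join(stream_whole_chunk('hi'))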

View File

@ -28,11 +28,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -40,6 +42,7 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
)
@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
}
self._task_state = WorkflowTaskState()
self._stream_generate_nodes = self._get_stream_generate_nodes()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -167,6 +171,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
@ -174,6 +186,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
@ -193,6 +206,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
@ -254,3 +272,142 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True
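
A self-contained toy version of the reverse walk that _get_end_start_at_node_ids performs, assuming the upstream NodeType string values 'start', 'if-else' and 'question-classifier':

def start_nodes_for(graph: dict, target_node_id: str) -> list[str]:
    # Walk ingoing edges backwards until a start or branching node is hit.
    nodes = {n['id']: n for n in graph['nodes']}
    found = []
    for edge in graph['edges']:
        if edge.get('target') != target_node_id:
            continue
        source = nodes.get(edge.get('source'))
        if not source:
            continue
        node_type = source['data']['type']
        if node_type in ('if-else', 'question-classifier'):
            found.append(target_node_id)  # streaming starts at the end node itself
        elif node_type == 'start':
            found.append(edge['source'])
        else:
            found.extend(start_nodes_for(graph, edge['source']))
    return found

graph = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'llm', 'data': {'type': 'llm'}},
        {'id': 'end', 'data': {'type': 'end'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'llm'},
        {'source': 'llm', 'target': 'end'},
    ],
}
assert start_nodes_for(graph, 'end') == ['start']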

View File

@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
@ -119,7 +120,15 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
"""
Publish text chunk
"""
pass
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""

View File

@ -9,9 +9,17 @@ from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
class StreamGenerateRoute(BaseModel):
class WorkflowStreamGenerateNodes(BaseModel):
"""
StreamGenerateRoute entity
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
@ -55,6 +63,8 @@ class WorkflowTaskState(TaskState):
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
@ -62,7 +72,7 @@ class AdvancedChatTaskState(WorkflowTaskState):
"""
usage: LLMUsage
current_stream_generate_state: Optional[StreamGenerateRoute] = None
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):

View File

@ -6,7 +6,7 @@ from yarl import URL
from config import get_env
from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.python_transformer import PythonTemplateTransformer
# Code Executor

View File

@ -55,6 +55,7 @@ if __name__ == '__main__':
"""
class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:

View File

@ -8,6 +8,8 @@
- anthropic.claude-3-haiku-v1:0
- cohere.command-light-text-v14
- cohere.command-text-v14
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2402-v1:0

View File

@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:return:
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
if isinstance(messages, str):
prompt = messages
else:
prompt = self._convert_messages_to_prompt(messages, prefix)
prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(
self._convert_one_message_to_text(message, model_prefix)
for message in messages
)
return text.rstrip()
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
"""
Convert a single message to a string.
@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
ai_prompt = "\n\nAssistant:"
elif model_prefix == "meta":
human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = ""
# LLAMA3
if model_name.startswith("llama3"):
human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
ai_prompt = "\n\nAssistant:"
else:
# LLAMA2
human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = ""
elif model_prefix == "mistral":
human_prompt_prefix = "<s>[INST]"
human_prompt_postfix = "[\\INST]\n"
@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
:param messages: List of PromptMessage to combine.
:param model_name: specific model name. Optional, just to distinguish llama2 and llama3
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
messages.append(AssistantPromptMessage(content=""))
text = "".join(
self._convert_one_message_to_text(message, model_prefix)
self._convert_one_message_to_text(message, model_prefix, model_name)
for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
"""
Create payload for bedrock api call depending on model provider
"""
payload = dict()
model_prefix = model.split('.')[0]
model_name = model.split('.')[1]
if model_prefix == "amazon":
payload["textGenerationConfig"] = { **model_parameters }
@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "meta":
payload = { **model_parameters }
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
else:
raise ValueError(f"Got unknown model prefix {model_prefix}")
@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
)
model_prefix = model.split('.')[0]
payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
# need workaround for ai21 models which doesn't support streaming
if stream and model_prefix != "ai21":
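
For the meta prefix, the user-message wrapping now depends on the model generation. A sketch of the prefixes the branches above choose, wrapped in a hypothetical helper:

def meta_user_wrapping(model_name: str) -> tuple[str, str]:
    # Mirrors the llama3/llama2 branches in _convert_one_message_to_text.
    if model_name.startswith('llama3'):
        return ("<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
    return ("\n[INST]", "[\\INST]\n")

prefix, postfix = meta_user_wrapping('llama3-8b-instruct-v1:0')
print(f"{prefix}Hello{postfix}")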

View File

@ -0,0 +1,23 @@
model: meta.llama3-70b-instruct-v1:0
label:
en_US: Llama 3 Instruct 70B
model_type: llm
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00265'
output: '0.0035'
unit: '0.00001'
currency: USD

View File

@ -0,0 +1,23 @@
model: meta.llama3-8b-instruct-v1:0
label:
en_US: Llama 3 Instruct 8B
model_type: llm
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.0004'
output: '0.0006'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,37 @@
model: abab6.5-chat
label:
en_US: Abab6.5-Chat
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.1
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 8192
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.03'
output: '0.03'
unit: '0.001'
currency: RMB
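
The pricing block reads as a price per 1/unit tokens. Assuming dify's usual cost formula (tokens × unit × price), a 1,000-token prompt against abab6.5-chat would cost:

from decimal import Decimal

tokens = 1000
unit = Decimal('0.001')        # prices quoted per 1,000 tokens
input_price = Decimal('0.03')  # RMB, from the pricing block above

print(tokens * unit * input_price)  # Decimal('0.03000'), i.e. 0.03 RMB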

View File

@ -0,0 +1,37 @@
model: abab6.5s-chat
label:
en_US: Abab6.5s-Chat
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 245760
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.1
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 245760
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.01'
output: '0.01'
unit: '0.001'
currency: RMB

View File

@ -1,7 +1,7 @@
- google/gemma-7b
- google/codegemma-7b
- meta/llama2-70b
- meta/llama3-8b
- meta/llama3-70b
- meta/llama3-8b-instruct
- meta/llama3-70b-instruct
- mistralai/mixtral-8x7b-instruct-v0.1
- fuyu-8b

View File

@ -1,7 +1,7 @@
model: meta/llama3-70b
model: meta/llama3-70b-instruct
label:
zh_Hans: meta/llama3-70b
en_US: meta/llama3-70b
zh_Hans: meta/llama3-70b-instruct
en_US: meta/llama3-70b-instruct
model_type: llm
features:
- agent-thought

View File

@ -1,7 +1,7 @@
model: meta/llama3-8b
model: meta/llama3-8b-instruct
label:
zh_Hans: meta/llama3-8b
en_US: meta/llama3-8b
zh_Hans: meta/llama3-8b-instruct
en_US: meta/llama3-8b-instruct
model_type: llm
features:
- agent-thought

View File

@ -26,8 +26,8 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
'google/gemma-7b': '',
'google/codegemma-7b': '',
'meta/llama2-70b': '',
'meta/llama3-8b': '',
'meta/llama3-70b': ''
'meta/llama3-8b-instruct': '',
'meta/llama3-70b-instruct': ''
}

View File

@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
version = credentials['model_version']
model_version = ''
if 'model_version' in credentials:
model_version = credentials['model_version']
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
inputs = {**model_parameters}
@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
if 'model_version' not in credentials:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
model_version = ''
if 'model_version' in credentials:
model_version = credentials['model_version']
if model.count("/") != 1:
raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
'format: {user_name}/{model_name}')
version = credentials['model_version']
try:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
self._check_text_generation_model(model_info_version, model, version)
if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
except ReplicateError as e:
raise CredentialsValidateFailedError(
f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e:
raise CredentialsValidateFailedError(str(e))
@staticmethod
def _check_text_generation_model(model_info_version, model_name, version):
def _check_text_generation_model(model_info_version, model_name, version, description):
if 'language model' in description.lower():
return
if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
@classmethod
def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
version = credentials['model_version']
model_version = ''
if 'model_version' in credentials:
model_version = credentials['model_version']
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
parameter_rules = []
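
model_version is now optional at every call site, with the same fallback repeated three times (generation, validation, parameter rules). The shared pattern as a hypothetical helper, using only client calls that appear in this diff:

def resolve_model_version(client, model: str, credentials: dict):
    # Prefer an explicit model_version, else fall back to the latest version.
    model_info = client.models.get(model)
    model_version = credentials.get('model_version', '')
    if model_version:
        return model_info.versions.get(model_version)
    return model_info.latest_version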

View File

@ -35,7 +35,7 @@ model_credential_schema:
label:
en_US: Model Version
type: text-input
required: true
required: false
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version
zh_Hans: 在此输入您的模型版本,默认为最新版本
en_US: Enter your model version, default to the latest version

View File

@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
user: Optional[str] = None) -> TextEmbeddingResult:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'
text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
if 'model_version' in credentials:
model_version = credentials['model_version']
else:
model_info = client.models.get(model)
model_version = model_info.latest_version.id
replicate_model_version = f'{model}:{model_version}'
text_input_key = self._get_text_input_key(model, model_version, client)
embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
texts)
@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
if 'model_version' not in credentials:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
try:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'
text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
if 'model_version' in credentials:
model_version = credentials['model_version']
else:
model_info = client.models.get(model)
model_version = model_info.latest_version.id
replicate_model_version = f'{model}:{model_version}'
text_input_key = self._get_text_input_key(model, model_version, client)
self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
['Hello worlds!'])

View File

@ -1,9 +1,23 @@
from collections.abc import Generator
from decimal import Decimal
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
REPETITION_PENALTY = "repetition_penalty"
TOP_K = "top_k"
features = []
return super().get_customizable_model_schema(model, cred_with_endpoint)
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")),
ModelPropertyKey.MODE: cred_with_endpoint.get('mode'),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=float(cred_with_endpoint.get('temperature', 0.7)),
min=0,
max=2,
precision=2
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
default=float(cred_with_endpoint.get('top_p', 1)),
min=0,
max=1,
precision=2
),
ParameterRule(
name=TOP_K,
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(cred_with_endpoint.get('top_k', 50)),
min=-2147483647,
max=2147483647,
precision=0
),
ParameterRule(
name=REPETITION_PENALTY,
label=I18nObject(en_US="Repetition Penalty"),
type=ParameterType.FLOAT,
default=float(cred_with_endpoint.get('repetition_penalty', 1)),
min=-3.4,
max=3.4,
precision=1
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
default=512,
min=1,
max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)),
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('frequency_penalty', 0)),
min=-2,
max=2
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="Presence Penalty"),
type=ParameterType.FLOAT,
default=float(credentials.get('presence_penalty', 0)),
min=-2,
max=2
)
],
pricing=PriceConfig(
input=Decimal(cred_with_endpoint.get('input_price', 0)),
output=Decimal(cred_with_endpoint.get('output_price', 0)),
unit=Decimal(cred_with_endpoint.get('unit', 0)),
currency=cred_with_endpoint.get('currency', "USD")
),
)
if cred_with_endpoint['mode'] == 'chat':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif cred_with_endpoint['mode'] == 'completion':
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}")
return entity
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:

View File

@ -32,3 +32,8 @@ parameter_rules:
zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量。不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
required: false
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8192

View File

@ -124,7 +124,7 @@ class MilvusVector(BaseVector):
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
def delete_by_ids(self, ids: list[str]) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
@ -136,7 +136,7 @@ class MilvusVector(BaseVector):
if utility.has_collection(self._collection_name, using=alias):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]

View File

@ -0,0 +1,12 @@
from uuid import UUID
from numpy import ndarray
from sqlalchemy.orm import DeclarativeBase, Mapped
class CollectionORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
text: Mapped[str]
meta: Mapped[dict]
vector: Mapped[ndarray]

View File

@ -0,0 +1,224 @@
import logging
from typing import Any
from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, root_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PgvectoRSConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields = []
class _Table(CollectionORM):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True} # noqa: RUF012
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str] = mapped_column(String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(Vector(dim))
self._table = _Table
self._distance_op = "<=>"
def get_type(self) -> str:
return 'pgvecto-rs'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
vector vector({dimension}) NOT NULL
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {self._collection_name} USING vectors(vector vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 30
segment.max_growing_segment_size = 2000
segment.max_sealed_segment_size = 30000000
[indexing.hnsw]
m=30
ef_construction=500
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
insert(self._table).values(
id=pk,
text=document.page_content,
meta=document.metadata,
vector=embedding,
),
)
pks.append(pk)
session.commit()
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.commit()
def delete(self) -> None:
with Session(self._client) as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
)
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
with Session(self._client) as session:
stmt = (
select(
self._table,
self._table.vector.op(self._distance_op, return_type=Float)(
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.order_by("distance")
)
res = session.execute(stmt)
results = [(row[0], row[1]) for row in res]
# Organize results.
docs = []
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# with Session(self._client) as session:
# select_statement = sql_text(
# f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
# )
# results = session.execute(select_statement).fetchall()
# if results:
# docs = []
# for result in results:
# doc = Document(page_content=result[0],
# metadata=result[1])
# docs.append(doc)
# return docs
return []
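
A hedged usage sketch for the new store, assuming the class lives at core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs and using the .env defaults shown earlier in this diff:

from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from core.rag.models.document import Document

config = PgvectoRSConfig(host='localhost', port=5431, user='postgres',
                         password='difyai123456', database='postgres')
store = PGVectoRS(collection_name='demo_collection', config=config, dim=4)

# create() sizes the collection from the first embedding, then indexes it.
store.create(texts=[Document(page_content='hello', metadata={'doc_id': 'd1'})],
             embeddings=[[0.1, 0.2, 0.3, 0.4]])
print(store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1))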

View File

@ -36,6 +36,8 @@ class QdrantConfig(BaseModel):
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
grpc_port: int = 6334
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
@ -51,7 +53,9 @@ class QdrantConfig(BaseModel):
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout,
'verify': self.endpoint.startswith('https')
'verify': self.endpoint.startswith('https'),
'grpc_port': self.grpc_port,
'prefer_grpc': self.prefer_grpc
}
@ -113,8 +117,7 @@ class QdrantVector(BaseVector):
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
field_schema=PayloadSchemaType.KEYWORD)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
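
With the two new fields, turning on QDRANT_GRPC_ENABLED simply adds the gRPC settings to the client kwargs. A sketch against the config above:

config = QdrantConfig(endpoint='http://localhost:6333', api_key='difyai123456',
                      root_path=None, grpc_port=6334, prefer_grpc=True)
print(config.to_qdrant_params())
# {'url': 'http://localhost:6333', 'api_key': 'difyai123456', 'timeout': 20,
#  'verify': False, 'grpc_port': 6334, 'prefer_grpc': True}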

View File

@ -1,16 +1,23 @@
import logging
from typing import Any
import uuid
from typing import Any, Optional
from pgvecto_rs.sdk import PGVectoRs, Record
from pydantic import BaseModel, root_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
class RelytConfig(BaseModel):
host: str
@ -36,16 +43,14 @@ class RelytConfig(BaseModel):
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, dim: int):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._client = PGVectoRs(
db_url=self._url,
collection_name=self._collection_name,
dimension=dim
)
self.client = create_engine(self._url)
self._fields = []
self._group_id = group_id
def get_type(self) -> str:
return 'relyt'
@ -54,6 +59,7 @@ class RelytVector(BaseVector):
index_params = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
@ -63,21 +69,21 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client._engine) as session:
drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
with Session(self.client) as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
id TEXT PRIMARY KEY,
document TEXT NOT NULL,
metadata JSON NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
CREATE INDEX {index_name}
ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 30
segment.max_growing_segment_size = 2000
@ -92,21 +98,62 @@ class RelytVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
pks = [str(r.id) for r in records]
self._client.insert(records)
return pks
from pgvecto_rs.sqlalchemy import Vector
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
chunks_table = Table(
self._collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", Vector(len(embeddings[0]))),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
return ids
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete_by_ids(ids)
self.delete_by_uuids(ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client._engine) as session:
with Session(self.client) as session:
select_statement = sql_text(
f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
)
result = session.execute(select_statement).fetchall()
if result:
@ -114,56 +161,140 @@ class RelytVector(BaseVector):
else:
return None
def delete_by_uuids(self, ids: list[str] = None):
"""Delete by vector IDs.
Args:
ids: List of ids to delete.
"""
from pgvecto_rs.sqlalchemy import Vector
if ids is None:
raise ValueError("No ids provided to delete.")
# Define the table schema
chunks_table = Table(
self._collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", Vector(self.embedding_dimension)),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
try:
with self.client.connect() as conn:
with conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e)) # noqa: T201
return False
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete_by_ids(ids)
self.delete_by_uuids(ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
with Session(self._client._engine) as session:
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
result = session.execute(select_statement).fetchall()
if result:
ids = [item[0] for item in result]
self._client.delete_by_ids(ids)
self.delete_by_uuids(ids)
def delete(self) -> None:
with Session(self._client._engine) as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
with Session(self.client) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client._engine) as session:
with Session(self.client) as session:
select_statement = sql_text(
f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
)
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from pgvecto_rs.sdk import filters
filter_condition = filters.meta_contains(kwargs.get('filter'))
results = self._client.search(
top_k=int(kwargs.get('top_k')),
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get('top_k')),
embedding=query_vector,
filter=filter_condition
filter=kwargs.get('filter')
)
# Organize results.
docs = []
for record, dis in results:
metadata = record.meta
metadata['score'] = dis
for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if dis > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
docs.append(doc)
if 1 - score > score_threshold:
docs.append(document)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
filter_condition = f"WHERE {' AND '.join(conditions)}"
# Define the base query
sql_query = f"""
set vectors.enable_search_growing = on;
set vectors.enable_search_write = on;
SELECT document, metadata, embedding <-> :embedding as distance
FROM "{self._collection_name}"
{filter_condition}
ORDER BY embedding <-> :embedding
LIMIT :k
"""
# Set up the query parameters
embedding_str = ", ".join(format(x) for x in embedding)
embedding_str = "[" + embedding_str + "]"
params = {"embedding": embedding_str, "k": k}
# Execute the query and fetch the results
with self.client.connect() as conn:
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
Document(
page_content=result.document,
metadata=result.metadata,
),
result.distance,
)
for result in results
]
return documents_with_scores
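
A usage sketch, with hypothetical names, showing how search_by_vector above consumes this method; note the caller converts the returned distance into a score as 1 - distance:

# sketch: restrict the search to one dataset via the group_id metadata filter;
# filter values are lists, matching the len(value)/value[0] handling above
results = vector.similarity_search_with_score_by_vector(
    embedding=query_vector,
    k=4,
    filter={'group_id': [dataset_id]},
)
for document, distance in results:
    print(document.page_content, 1 - distance)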
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz/relyt do not support bm25 search
return []

View File

@ -27,6 +27,12 @@ class BaseVector(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError

View File

@ -86,7 +86,9 @@ class Vector:
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
grpc_port=config.get('QDRANT_GRPC_PORT'),
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
)
)
elif vector_type == "milvus":
@ -126,7 +128,6 @@ class Vector:
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
dim = len(self._embeddings.embed_query("hello relyt"))
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
@ -136,6 +137,31 @@ class Vector:
password=config.get('RELYT_PASSWORD'),
database=config.get('RELYT_DATABASE'),
),
group_id=self._dataset.id
)
elif vector_type == "pgvecto_rs":
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": 'pgvecto_rs',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
dim = len(self._embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=config.get('PGVECTO_RS_HOST'),
port=config.get('PGVECTO_RS_PORT'),
user=config.get('PGVECTO_RS_USER'),
password=config.get('PGVECTO_RS_PASSWORD'),
database=config.get('PGVECTO_RS_DATABASE'),
),
dim=dim
)
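
For context, a sketch of the settings this branch reads; the keys match the config.get() calls above and all values are placeholders:

# sketch: example configuration for the pgvecto_rs branch (hypothetical values)
PGVECTO_RS_EXAMPLE_SETTINGS = {
    'PGVECTO_RS_HOST': 'localhost',
    'PGVECTO_RS_PORT': 5431,
    'PGVECTO_RS_USER': 'postgres',
    'PGVECTO_RS_PASSWORD': 'password',
    'PGVECTO_RS_DATABASE': 'dify',
}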
else:

View File

@ -29,8 +29,7 @@ class WordExtractor(BaseExtractor):
if r.status_code != 200:
raise ValueError(
"Check the url of your file; returned status code %s"
% r.status_code
f"Check the url of your file; returned status code {r.status_code}"
)
self.web_path = self.file_path
@ -38,7 +37,7 @@ class WordExtractor(BaseExtractor):
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError("File path %s is not a valid file or url" % self.file_path)
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def __del__(self) -> None:
if hasattr(self, "temp_file"):

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="111" height="111" viewBox="0 0 111 111" fill="none">
<text x="0" y="90" font-family="Verdana" font-size="85" fill="black">🔥</text>
</svg>


View File

@ -0,0 +1,23 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class FirecrawlProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
# Example validation using the Crawl tool
CrawlTool().fork_tool_runtime(
meta={"credentials": credentials}
).invoke(
user_id='',
tool_parameters={
"url": "https://example.com",
"includes": '',
"excludes": '',
"limit": 1,
"onlyMainContent": True,
}
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,24 @@
identity:
author: Richards Tu
name: firecrawl
label:
en_US: Firecrawl
zh_CN: Firecrawl
description:
en_US: Firecrawl API integration for web crawling and scraping.
zh_CN: Firecrawl API 集成,用于网页爬取和数据抓取。
icon: icon.svg
credentials_for_provider:
firecrawl_api_key:
type: secret-input
required: true
label:
en_US: Firecrawl API Key
zh_CN: Firecrawl API 密钥
placeholder:
en_US: Please input your Firecrawl API key
zh_CN: 请输入您的 Firecrawl API 密钥
help:
en_US: Get your Firecrawl API key from your Firecrawl account settings.
zh_CN: 从您的 Firecrawl 账户设置中获取 Firecrawl API 密钥。
url: https://www.firecrawl.dev/account

View File

@ -0,0 +1,50 @@
from typing import Any, Union
from firecrawl import FirecrawlApp
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class CrawlTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
# initialize the app object with the api key
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'])
options = {
'crawlerOptions': {
'excludes': tool_parameters.get('excludes', '').split(',') if tool_parameters.get('excludes') else [],
'includes': tool_parameters.get('includes', '').split(',') if tool_parameters.get('includes') else [],
'limit': tool_parameters.get('limit', 5)
},
'pageOptions': {
'onlyMainContent': tool_parameters.get('onlyMainContent', False)
}
}
# crawl the url
crawl_result = app.crawl_url(
url=tool_parameters['url'],
params=options,
wait_until_done=True,
)
# reformat crawl result
crawl_output = "**Crawl Result**\n\n"
try:
for result in crawl_result:
crawl_output += f"**- Title:** {result.get('metadata', {}).get('title', '')}\n"
crawl_output += f"**- Description:** {result.get('metadata', {}).get('description', '')}\n"
crawl_output += f"**- URL:** {result.get('metadata', {}).get('ogUrl', '')}\n\n"
crawl_output += f"**- Web Content:**\n{result.get('markdown', '')}\n\n"
crawl_output += "---\n\n"
except Exception as e:
# `result` may be unbound if the loop failed on its first item, so only
# report the error instead of re-reading the entry that failed
crawl_output += f"An error occurred while formatting the crawl result: {str(e)}\n"
return self.create_text_message(crawl_output)

View File

@ -0,0 +1,78 @@
identity:
name: crawl
author: Richards Tu
label:
en_US: Crawl
zh_Hans: 爬取
description:
human:
en_US: Extract data from a website by crawling through a URL.
zh_Hans: 通过URL从网站中提取数据。
llm: This tool initiates a web crawl to extract data from a specified URL. It allows configuring crawler options such as including or excluding URL patterns, generating alt text for images using LLMs (paid plan required), limiting the maximum number of pages to crawl, and returning only the main content of the page. The tool can return either a list of crawled documents or a list of URLs based on the provided options.
parameters:
- name: url
type: string
required: true
label:
en_US: URL to crawl
zh_Hans: 要爬取的URL
human_description:
en_US: The URL of the website to crawl and extract data from.
zh_Hans: 要爬取并提取数据的网站URL。
llm_description: The URL of the website that needs to be crawled. This is a required parameter.
form: llm
- name: includes
type: string
required: false
label:
en_US: URL patterns to include
zh_Hans: 要包含的URL模式
human_description:
en_US: Specify URL patterns to include during the crawl. Only pages matching these patterns will be crawled; you can use ',' to separate multiple patterns.
zh_Hans: 指定爬取过程中要包含的URL模式。只有与这些模式匹配的页面才会被爬取。
form: form
default: ''
- name: excludes
type: string
required: false
label:
en_US: URL patterns to exclude
zh_Hans: 要排除的URL模式
human_description:
en_US: Specify URL patterns to exclude during the crawl. Pages matching these patterns will be skipped; you can use ',' to separate multiple patterns.
zh_Hans: 指定爬取过程中要排除的URL模式。匹配这些模式的页面将被跳过。
form: form
default: 'blog/*'
- name: limit
type: number
required: false
label:
en_US: Maximum number of pages to crawl
zh_Hans: 最大爬取页面数
human_description:
en_US: Specify the maximum number of pages to crawl. The crawler will stop after reaching this limit.
zh_Hans: 指定要爬取的最大页面数。爬虫将在达到此限制后停止。
form: form
min: 1
max: 20
default: 5
- name: onlyMainContent
type: boolean
required: false
label:
en_US: Only return the main content of the page
zh_Hans: 仅返回页面的主要内容
human_description:
en_US: If enabled, the crawler will only return the main content of the page, excluding headers, navigation, footers, etc.
zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
form: form
options:
- value: true
label:
en_US: Yes
zh_Hans: 是
- value: false
label:
en_US: No
zh_Hans: 否
default: false
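
A usage sketch mirroring the provider's credential check earlier in this diff; the API key and parameter values are hypothetical:

# sketch: invoke the crawl tool with the parameters documented above
CrawlTool().fork_tool_runtime(
    meta={'credentials': {'firecrawl_api_key': 'fc-my-key'}}
).invoke(
    user_id='user-1',
    tool_parameters={
        'url': 'https://example.com',
        'includes': '',
        'excludes': 'blog/*',
        'limit': 5,
        'onlyMainContent': True,
    },
)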

View File

@ -1,14 +1,14 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool
from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class Judge0CEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
SubmitCodeExecutionTaskTool().fork_tool_runtime(
ExecuteCodeTool().fork_tool_runtime(
meta={
"credentials": credentials,
}

View File

@ -0,0 +1,59 @@
import json
from typing import Any, Union
import requests
from httpx import post
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class ExecuteCodeTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key = self.runtime.credentials['X-RapidAPI-Key']
url = "https://judge0-ce.p.rapidapi.com/submissions"
querystring = {"base64_encoded": "false", "fields": "*"}
headers = {
"Content-Type": "application/json",
"X-RapidAPI-Key": api_key,
"X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
}
payload = {
"language_id": tool_parameters['language_id'],
"source_code": tool_parameters['source_code'],
"stdin": tool_parameters.get('stdin', ''),
"expected_output": tool_parameters.get('expected_output', ''),
"additional_files": tool_parameters.get('additional_files', ''),
}
response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
if response.status_code != 201:
raise Exception(response.text)
token = response.json()['token']
url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}"
headers = {
"X-RapidAPI-Key": api_key
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
result = response.json()
return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n"
f"stderr: {result.get('stderr', '')}\n"
f"compile_output: {result.get('compile_output', '')}\n"
f"message: {result.get('message', '')}\n"
f"status: {result['status']['description']}\n"
f"time: {result.get('time', '')} seconds\n"
f"memory: {result.get('memory', '')} bytes")
else:
return self.create_text_message(text=f"Error retrieving submission details: {response.text}")

View File

@ -2,13 +2,13 @@ identity:
name: submitCodeExecutionTask
author: Richards Tu
label:
en_US: Submit Code Execution Task
zh_Hans: 提交代码执行任务
en_US: Submit Code Execution Task to Judge0 CE and get execution result.
zh_Hans: 提交代码执行任务到 Judge0 CE 并获取执行结果。
description:
human:
en_US: A tool for submitting code execution task to Judge0 CE.
zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。
llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission.
en_US: A tool for executing code and getting the result.
zh_Hans: 一个用于执行代码并获取结果的工具。
llm: This tool is used for executing code and getting the result.
parameters:
- name: source_code
type: string

View File

@ -1,37 +0,0 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GetExecutionResultTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key = self.runtime.credentials['X-RapidAPI-Key']
url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}"
headers = {
"X-RapidAPI-Key": api_key
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
result = response.json()
return self.create_text_message(text=f"Submission details:\n"
f"stdout: {result.get('stdout', '')}\n"
f"stderr: {result.get('stderr', '')}\n"
f"compile_output: {result.get('compile_output', '')}\n"
f"message: {result.get('message', '')}\n"
f"status: {result['status']['description']}\n"
f"time: {result.get('time', '')} seconds\n"
f"memory: {result.get('memory', '')} bytes")
else:
return self.create_text_message(text=f"Error retrieving submission details: {response.text}")

View File

@ -1,23 +0,0 @@
identity:
name: getExecutionResult
author: Richards Tu
label:
en_US: Get Execution Result
zh_Hans: 获取执行结果
description:
human:
en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。
llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
parameters:
- name: token
type: string
required: true
label:
en_US: Token
zh_Hans: 令牌
human_description:
en_US: The submission's unique token.
zh_Hans: 提交的唯一令牌。
llm_description: The submission's unique token. MUST get from submitCodeExecution.
form: llm

View File

@ -1,49 +0,0 @@
import json
from typing import Any, Union
from httpx import post
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class SubmitCodeExecutionTaskTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key = self.runtime.credentials['X-RapidAPI-Key']
source_code = tool_parameters['source_code']
language_id = tool_parameters['language_id']
stdin = tool_parameters.get('stdin', '')
expected_output = tool_parameters.get('expected_output', '')
additional_files = tool_parameters.get('additional_files', '')
url = "https://judge0-ce.p.rapidapi.com/submissions"
querystring = {"base64_encoded": "false", "fields": "*"}
payload = {
"language_id": language_id,
"source_code": source_code,
"stdin": stdin,
"expected_output": expected_output,
"additional_files": additional_files,
}
headers = {
"content-type": "application/json",
"Content-Type": "application/json",
"X-RapidAPI-Key": api_key,
"X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
}
response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
if response.status_code != 201:
raise Exception(response.text)
token = response.json()['token']
return self.create_text_message(text=token)

View File

@ -42,20 +42,19 @@ def get_url(url: str, user_agent: str = None) -> str:
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if head_response.status_code != 200:
return "URL returned status code {}.".format(head_response.status_code)
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# check content-type
main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():

View File

@ -141,10 +141,10 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, str):
raise ValueError(f"{variable} in output form must be a string")
raise ValueError(f"Output variable `{variable}` must be a string")
if len(value) > MAX_STRING_LENGTH:
raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters')
return value.replace('\x00', '')
@ -156,15 +156,15 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, int | float):
raise ValueError(f"{variable} in output form must be a number")
raise ValueError(f"Output variable `{variable}` must be a number")
if value > MAX_NUMBER or value < MIN_NUMBER:
raise ValueError(f'{variable} in input form is out of range.')
raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.')
if isinstance(value, float):
# raise error if precision is too high
if len(str(value).split('.')[1]) > MAX_PRECISION:
raise ValueError(f'{variable} in output form has too high precision.')
raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.')
return value
@ -271,7 +271,7 @@ class CodeNode(BaseNode):
if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
f'{prefix}{dot}{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters.'
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.'
)
transformed_result[output_name] = [
@ -290,7 +290,7 @@ class CodeNode(BaseNode):
if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
raise ValueError(
f'{prefix}{dot}{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters.'
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.'
)
transformed_result[output_name] = [
@ -309,7 +309,7 @@ class CodeNode(BaseNode):
if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH:
raise ValueError(
f'{prefix}{dot}{output_name} in output form must be less than {MAX_OBJECT_ARRAY_LENGTH} characters.'
f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.'
)
for i, value in enumerate(result[output_name]):

View File

@ -36,6 +36,49 @@ class EndNode(BaseNode):
outputs=outputs
)
@classmethod
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
"""
Extract generate nodes
:param graph: graph
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@classmethod
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
"""
Extract generate nodes from node data
:param graph: graph
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes')
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs
generate_nodes = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_mapping:
node = node_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
generate_nodes.append(node_id)
# remove duplicates
generate_nodes = list(set(generate_nodes))
return generate_nodes
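
A minimal sketch of what this extraction returns, assuming EndNodeData.outputs holds VariableSelector entries as elsewhere in the workflow module (the graph and selector values are hypothetical):

# sketch: an end-node output selecting the `text` output of an LLM node
graph = {'nodes': [{'id': 'llm_1', 'data': {'type': 'llm'}}]}
node_data = EndNodeData(outputs=[
    VariableSelector(variable='result', value_selector=['llm_1', 'text']),
])
assert EndNode.extract_generate_nodes_from_node_data(graph, node_data) == ['llm_1']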
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""

View File

@ -35,9 +35,15 @@ class HttpRequestNodeData(BaseNodeData):
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
data: Union[None, str]
class Timeout(BaseModel):
connect: int
read: int
write: int
method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
url: str
authorization: Authorization
headers: str
params: str
body: Optional[Body]
body: Optional[Body]
timeout: Optional[Timeout]

View File

@ -13,7 +13,6 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)
MAX_BINARY_SIZE = 1024 * 1024 * 10 # 10MB
READABLE_MAX_BINARY_SIZE = '10MB'
MAX_TEXT_SIZE = 1024 * 1024 // 10 # 0.1MB
@ -137,14 +136,16 @@ class HttpExecutor:
files: Union[None, dict[str, Any]]
boundary: str
variable_selectors: list[VariableSelector]
timeout: HttpRequestNodeData.Timeout
def __init__(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None):
"""
init
"""
self.server_url = node_data.url
self.method = node_data.method
self.authorization = node_data.authorization
self.timeout = timeout
self.params = {}
self.headers = {}
self.body = None
@ -307,7 +308,7 @@ class HttpExecutor:
'url': self.server_url,
'headers': headers,
'params': self.params,
'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
'follow_redirects': True
}

View File

@ -1,4 +1,5 @@
import logging
import os
from mimetypes import guess_extension
from os import path
from typing import cast
@ -12,18 +13,49 @@ from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus
MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300'))
MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT),
read=min(60, MAX_READ_TIMEOUT),
write=min(20, MAX_WRITE_TIMEOUT))
class HttpRequestNode(BaseNode):
_node_data_cls = HttpRequestNodeData
node_type = NodeType.HTTP_REQUEST
@classmethod
def get_default_config(cls) -> dict:
return {
"type": "http-request",
"config": {
"method": "get",
"authorization": {
"type": "no-auth",
},
"body": {
"type": "none"
},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.dict(),
"max_connect_timeout": MAX_CONNECT_TIMEOUT,
"max_read_timeout": MAX_READ_TIMEOUT,
"max_write_timeout": MAX_WRITE_TIMEOUT,
}
},
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool)
http_executor = HttpExecutor(node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=variable_pool)
# invoke http executor
response = http_executor.invoke()
@ -38,7 +70,7 @@ class HttpRequestNode(BaseNode):
error=str(e),
process_data=process_data
)
files = self.extract_files(http_executor.server_url, response)
return NodeRunResult(
@ -54,6 +86,16 @@ class HttpRequestNode(BaseNode):
}
)
def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeData.Timeout:
timeout = node_data.timeout
if timeout is None:
return HTTP_REQUEST_DEFAULT_TIMEOUT
timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT)
timeout.read = min(timeout.read, MAX_READ_TIMEOUT)
timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT)
return timeout
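
A sketch of the clamping behavior, using the default caps defined above (node and node_data are hypothetical instances):

# sketch: with HTTP_REQUEST_MAX_CONNECT_TIMEOUT unset, the cap is 300, so an
# oversized connect timeout is clamped while read/write pass through unchanged
node_data.timeout = HttpRequestNodeData.Timeout(connect=900, read=30, write=30)
resolved = node._get_request_timeout(node_data)
assert (resolved.connect, resolved.read, resolved.write) == (300, 30, 30)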
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
"""
@ -62,7 +104,7 @@ class HttpRequestNode(BaseNode):
:return:
"""
try:
http_executor = HttpExecutor(node_data=node_data)
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
variable_selectors = http_executor.variable_selectors
@ -84,7 +126,7 @@ class HttpRequestNode(BaseNode):
# if not image, return directly
if 'image' not in mimetype:
return files
if mimetype:
# extract filename from url
filename = path.basename(url)

View File

@ -79,7 +79,6 @@ class QuestionClassifierNode(LLMNode):
prompt_messages=prompt_messages
),
'usage': jsonable_encoder(usage),
'topics': categories[0] if categories else ''
}
outputs = {
'class_name': categories[0] if categories else ''

View File

@ -1,80 +1,42 @@
import os
import shutil
from collections.abc import Generator
from contextlib import closing
from datetime import datetime, timedelta, timezone
from typing import Union
import boto3
import oss2 as aliyun_s3
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from botocore.client import Config
from botocore.exceptions import ClientError
from flask import Flask
from extensions.storage.aliyun_storage import AliyunStorage
from extensions.storage.azure_storage import AzureStorage
from extensions.storage.google_storage import GoogleStorage
from extensions.storage.local_storage import LocalStorage
from extensions.storage.s3_storage import S3Storage
class Storage:
def __init__(self):
self.storage_type = None
self.bucket_name = None
self.client = None
self.folder = None
self.storage_runner = None
def init_app(self, app: Flask):
self.storage_type = app.config.get('STORAGE_TYPE')
if self.storage_type == 's3':
self.bucket_name = app.config.get('S3_BUCKET_NAME')
self.client = boto3.client(
's3',
aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
endpoint_url=app.config.get('S3_ENDPOINT'),
region_name=app.config.get('S3_REGION'),
config=Config(s3={'addressing_style': app.config.get('S3_ADDRESS_STYLE')})
storage_type = app.config.get('STORAGE_TYPE')
if storage_type == 's3':
self.storage_runner = S3Storage(
app=app
)
elif self.storage_type == 'azure-blob':
self.bucket_name = app.config.get('AZURE_BLOB_CONTAINER_NAME')
sas_token = generate_account_sas(
account_name=app.config.get('AZURE_BLOB_ACCOUNT_NAME'),
account_key=app.config.get('AZURE_BLOB_ACCOUNT_KEY'),
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
elif storage_type == 'azure-blob':
self.storage_runner = AzureStorage(
app=app
)
self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'),
credential=sas_token)
elif self.storage_type == 'aliyun-oss':
self.bucket_name = app.config.get('ALIYUN_OSS_BUCKET_NAME')
self.client = aliyun_s3.Bucket(
aliyun_s3.Auth(app.config.get('ALIYUN_OSS_ACCESS_KEY'), app.config.get('ALIYUN_OSS_SECRET_KEY')),
app.config.get('ALIYUN_OSS_ENDPOINT'),
self.bucket_name,
connect_timeout=30
elif storage_type == 'aliyun-oss':
self.storage_runner = AliyunStorage(
app=app
)
elif storage_type == 'google-storage':
self.storage_runner = GoogleStorage(
app=app
)
else:
self.folder = app.config.get('STORAGE_LOCAL_PATH')
if not os.path.isabs(self.folder):
self.folder = os.path.join(app.root_path, self.folder)
self.storage_runner = LocalStorage(app=app)
def save(self, filename, data):
if self.storage_type == 's3':
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
elif self.storage_type == 'azure-blob':
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
elif self.storage_type == 'aliyun-oss':
self.client.put_object(filename, data)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True)
with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)
self.storage_runner.save(filename, data)
def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
if stream:
@ -83,131 +45,19 @@ class Storage:
return self.load_once(filename)
def load_once(self, filename: str) -> bytes:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
elif self.storage_type == 'azure-blob':
blob = self.client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
elif self.storage_type == 'aliyun-oss':
with closing(self.client.get_object(filename)) as obj:
data = obj.read()
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
data = f.read()
return data
return self.storage_runner.load_once(filename)
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
for chunk in response['Body'].iter_chunks():
yield chunk
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
elif self.storage_type == 'azure-blob':
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
with closing(blob.download_blob()) as blob_stream:
while chunk := blob_stream.readall(4096):
yield chunk
elif self.storage_type == 'aliyun-oss':
with closing(self.client.get_object(filename)) as obj:
while chunk := obj.read(4096):
yield chunk
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
while chunk := f.read(4096): # Read in chunks of 4KB
yield chunk
return generate()
return self.storage_runner.load_stream(filename)
def download(self, filename, target_filepath):
if self.storage_type == 's3':
with closing(self.client) as client:
client.download_file(self.bucket_name, filename, target_filepath)
elif self.storage_type == 'azure-blob':
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
elif self.storage_type == 'aliyun-oss':
self.client.get_object_to_file(filename, target_filepath)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
shutil.copyfile(filename, target_filepath)
self.storage_runner.download(filename, target_filepath)
def exists(self, filename):
if self.storage_type == 's3':
with closing(self.client) as client:
try:
client.head_object(Bucket=self.bucket_name, Key=filename)
return True
except:
return False
elif self.storage_type == 'azure-blob':
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
elif self.storage_type == 'aliyun-oss':
return self.client.object_exists(filename)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
return os.path.exists(filename)
return self.storage_runner.exists(filename)
def delete(self, filename):
if self.storage_type == 's3':
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
elif self.storage_type == 'azure-blob':
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.delete_blob(filename)
elif self.storage_type == 'aliyun-oss':
self.client.delete_object(filename)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if os.path.exists(filename):
os.remove(filename)
return self.storage_runner.delete(filename)
storage = Storage()

View File

@ -0,0 +1,48 @@
from collections.abc import Generator
from contextlib import closing
import oss2 as aliyun_s3
from flask import Flask
from extensions.storage.base_storage import BaseStorage
class AliyunStorage(BaseStorage):
"""Implementation for aliyun storage.
"""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME')
self.client = aliyun_s3.Bucket(
aliyun_s3.Auth(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')),
app_config.get('ALIYUN_OSS_ENDPOINT'),
self.bucket_name,
connect_timeout=30
)
def save(self, filename, data):
self.client.put_object(filename, data)
def load_once(self, filename: str) -> bytes:
with closing(self.client.get_object(filename)) as obj:
data = obj.read()
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
with closing(self.client.get_object(filename)) as obj:
while chunk := obj.read(4096):
yield chunk
return generate()
def download(self, filename, target_filepath):
self.client.get_object_to_file(filename, target_filepath)
def exists(self, filename):
return self.client.object_exists(filename)
def delete(self, filename):
self.client.delete_object(filename)

View File

@ -0,0 +1,58 @@
from collections.abc import Generator
from contextlib import closing
from datetime import datetime, timedelta, timezone
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from flask import Flask
from extensions.storage.base_storage import BaseStorage
class AzureStorage(BaseStorage):
"""Implementation for azure storage.
"""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME')
sas_token = generate_account_sas(
account_name=app_config.get('AZURE_BLOB_ACCOUNT_NAME'),
account_key=app_config.get('AZURE_BLOB_ACCOUNT_KEY'),
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
)
self.client = BlobServiceClient(account_url=app_config.get('AZURE_BLOB_ACCOUNT_URL'),
credential=sas_token)
def save(self, filename, data):
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
def load_once(self, filename: str) -> bytes:
blob = self.client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
with closing(blob.download_blob()) as blob_stream:
# StorageStreamDownloader.readall() takes no size argument; iterate chunks() instead
yield from blob_stream.chunks()
return generate()
def download(self, filename, target_filepath):
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
def exists(self, filename):
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
def delete(self, filename):
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.delete_blob(filename)

View File

@ -0,0 +1,38 @@
"""Abstract interface for file storage implementations."""
from abc import ABC, abstractmethod
from collections.abc import Generator
from flask import Flask
class BaseStorage(ABC):
"""Interface for file storage.
"""
app = None
def __init__(self, app: Flask):
self.app = app
@abstractmethod
def save(self, filename, data):
raise NotImplementedError
@abstractmethod
def load_once(self, filename: str) -> bytes:
raise NotImplementedError
@abstractmethod
def load_stream(self, filename: str) -> Generator:
raise NotImplementedError
@abstractmethod
def download(self, filename, target_filepath):
raise NotImplementedError
@abstractmethod
def exists(self, filename):
raise NotImplementedError
@abstractmethod
def delete(self, filename):
raise NotImplementedError

View File

@ -0,0 +1,56 @@
import base64
import json
from collections.abc import Generator
from contextlib import closing
from flask import Flask
from google.cloud import storage as GoogleCloudStorage
from extensions.storage.base_storage import BaseStorage
class GoogleStorage(BaseStorage):
"""Implementation for google storage.
"""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME')
service_account_json = base64.b64decode(app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')).decode('utf-8')
# Client.from_service_account_json() expects a file path, but the decoded value
# is the JSON content itself, so parse it and build the client from the dict
self.client = GoogleCloudStorage.Client.from_service_account_info(json.loads(service_account_json))
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
# `data` is raw bytes, not a file object, so upload it as a bytes payload
blob.upload_from_string(data)
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
data = blob.download_as_bytes()
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
with closing(blob.open(mode='rb')) as blob_stream:
while chunk := blob_stream.read(4096):
yield chunk
return generate()
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
# google-cloud-storage blobs have no download_blob(); write straight to the target path
blob.download_to_filename(target_filepath)
def exists(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
return blob.exists()
def delete(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
bucket.delete_blob(filename)

View File

@ -0,0 +1,88 @@
import os
import shutil
from collections.abc import Generator
from flask import Flask
from extensions.storage.base_storage import BaseStorage
class LocalStorage(BaseStorage):
"""Implementation for local storage.
"""
def __init__(self, app: Flask):
super().__init__(app)
folder = self.app.config.get('STORAGE_LOCAL_PATH')
if not os.path.isabs(folder):
folder = os.path.join(app.root_path, folder)
self.folder = folder
def save(self, filename, data):
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True)
with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)
def load_once(self, filename: str) -> bytes:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
data = f.read()
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
while chunk := f.read(4096): # Read in chunks of 4KB
yield chunk
return generate()
def download(self, filename, target_filepath):
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
shutil.copyfile(filename, target_filepath)
def exists(self, filename):
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
return os.path.exists(filename)
def delete(self, filename):
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if os.path.exists(filename):
os.remove(filename)

View File

@ -0,0 +1,68 @@
from collections.abc import Generator
from contextlib import closing
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from flask import Flask
from extensions.storage.base_storage import BaseStorage
class S3Storage(BaseStorage):
"""Implementation for s3 storage.
"""
def __init__(self, app: Flask):
super().__init__(app)
app_config = self.app.config
self.bucket_name = app_config.get('S3_BUCKET_NAME')
self.client = boto3.client(
's3',
aws_secret_access_key=app_config.get('S3_SECRET_KEY'),
aws_access_key_id=app_config.get('S3_ACCESS_KEY'),
endpoint_url=app_config.get('S3_ENDPOINT'),
region_name=app_config.get('S3_REGION'),
config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')})
)
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
def load_once(self, filename: str) -> bytes:
try:
with closing(self.client) as client:
data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response['Body'].iter_chunks()
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
return generate()
def download(self, filename, target_filepath):
with closing(self.client) as client:
client.download_file(self.bucket_name, filename, target_filepath)
def exists(self, filename):
with closing(self.client) as client:
try:
client.head_object(Bucket=self.bucket_name, Key=filename)
return True
except:
return False
def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -153,23 +153,12 @@ class NotionOAuth(OAuthDataSource):
# get page detail
for page_result in page_results:
page_id = page_result['id']
if 'Name' in page_result['properties']:
if len(page_result['properties']['Name']['title']) > 0:
page_name = page_result['properties']['Name']['title'][0]['plain_text']
else:
page_name = 'Untitled'
elif 'title' in page_result['properties']:
if len(page_result['properties']['title']['title']) > 0:
page_name = page_result['properties']['title']['title'][0]['plain_text']
else:
page_name = 'Untitled'
elif 'Title' in page_result['properties']:
if len(page_result['properties']['Title']['title']) > 0:
page_name = page_result['properties']['Title']['title'][0]['plain_text']
else:
page_name = 'Untitled'
else:
page_name = 'Untitled'
page_name = 'Untitled'
for key in ['Name', 'title', 'Title', 'Page']:
if key in page_result['properties']:
if len(page_result['properties'][key].get('title', [])) > 0:
page_name = page_result['properties'][key]['title'][0]['plain_text']
break
page_icon = page_result['icon']
if page_icon:
icon_type = page_icon['type']

View File

@ -6,6 +6,7 @@ Create Date: ${create_date}
"""
from alembic import op
import models as models
import sqlalchemy as sa
${imports if imports else ""}

View File

@ -1,5 +1,8 @@
from enum import Enum
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
class CreatedByRole(Enum):
"""
@ -42,3 +45,27 @@ class CreatedFrom(Enum):
if role.value == value:
return role
raise ValueError(f'invalid createdFrom value {value}')
class StringUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == 'postgresql':
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql':
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
if value is None:
return value
return str(value)
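
A short sketch of the round-trip behavior:

import uuid

u = uuid.uuid4()
str(u)   # what PostgreSQL receives from process_bind_param
u.hex    # what other dialects receive: 32 hex characters, no dashes
str(u)   # what callers get back from process_result_value on any dialect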

View File

@ -2,9 +2,9 @@ import enum
import json
from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
from models import StringUUID
class AccountStatus(str, enum.Enum):
@ -22,7 +22,7 @@ class Account(UserMixin, db.Model):
db.Index('account_email_idx', 'email')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
name = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=True)
@ -128,7 +128,7 @@ class Tenant(db.Model):
db.PrimaryKeyConstraint('id', name='tenant_pkey'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
name = db.Column(db.String(255), nullable=False)
encrypt_public_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
@ -168,12 +168,12 @@ class TenantAccountJoin(db.Model):
db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
account_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
role = db.Column(db.String(16), nullable=False, server_default='normal')
invited_by = db.Column(UUID, nullable=True)
invited_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -186,8 +186,8 @@ class AccountIntegrate(db.Model):
db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
account_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
account_id = db.Column(StringUUID, nullable=False)
provider = db.Column(db.String(16), nullable=False)
open_id = db.Column(db.String(255), nullable=False)
encrypted_token = db.Column(db.String(255), nullable=False)
@ -208,7 +208,7 @@ class InvitationCode(db.Model):
code = db.Column(db.String(32), nullable=False)
status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
used_at = db.Column(db.DateTime)
used_by_tenant_id = db.Column(UUID)
used_by_account_id = db.Column(UUID)
used_by_tenant_id = db.Column(StringUUID)
used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -1,8 +1,7 @@
import enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
from models import StringUUID
class APIBasedExtensionPoint(enum.Enum):
@ -19,8 +18,8 @@ class APIBasedExtension(db.Model):
db.Index('api_based_extension_tenant_idx', 'tenant_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
api_endpoint = db.Column(db.String(255), nullable=False)
api_key = db.Column(db.Text, nullable=False)

View File

@ -4,10 +4,11 @@ import pickle
from json import JSONDecodeError
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
@ -22,8 +23,8 @@ class Dataset(db.Model):
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True)
provider = db.Column(db.String(255), nullable=False,
@ -33,15 +34,15 @@ class Dataset(db.Model):
data_source_type = db.Column(db.String(255))
indexing_technique = db.Column(db.String(255), nullable=True)
index_struct = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property
@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model):
db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
mode = db.Column(db.String(255), nullable=False,
server_default=db.text("'automatic'::character varying"))
rules = db.Column(db.Text, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -197,19 +198,19 @@ class Document(db.Model):
)
# initial fields
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False)
data_source_info = db.Column(db.Text, nullable=True)
dataset_process_rule_id = db.Column(UUID, nullable=True)
dataset_process_rule_id = db.Column(StringUUID, nullable=True)
batch = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_from = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_api_request_id = db.Column(UUID, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_api_request_id = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -234,7 +235,7 @@ class Document(db.Model):
# pause
is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
paused_by = db.Column(UUID, nullable=True)
paused_by = db.Column(StringUUID, nullable=True)
paused_at = db.Column(db.DateTime, nullable=True)
# error
@ -247,11 +248,11 @@ class Document(db.Model):
enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True)
disabled_by = db.Column(StringUUID, nullable=True)
archived = db.Column(db.Boolean, nullable=False,
server_default=db.text('false'))
archived_reason = db.Column(db.String(255), nullable=True)
archived_by = db.Column(UUID, nullable=True)
archived_by = db.Column(StringUUID, nullable=True)
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -356,11 +357,11 @@ class DocumentSegment(db.Model):
)
# initial fields
id = db.Column(UUID, nullable=False,
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
document_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True)
@ -377,13 +378,13 @@ class DocumentSegment(db.Model):
enabled = db.Column(db.Boolean, nullable=False,
server_default=db.text('true'))
disabled_at = db.Column(db.DateTime, nullable=True)
disabled_by = db.Column(UUID, nullable=True)
disabled_by = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False,
server_default=db.text("'waiting'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID, nullable=True)
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
indexing_at = db.Column(db.DateTime, nullable=True)
@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model):
db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
)
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
@ -438,13 +439,13 @@ class DatasetQuery(db.Model):
db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(StringUUID, nullable=False)
content = db.Column(db.Text, nullable=False)
source = db.Column(db.String(255), nullable=False)
source_app_id = db.Column(UUID, nullable=True)
source_app_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model):
db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(UUID, nullable=False, unique=True)
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
dataset_id = db.Column(StringUUID, nullable=False, unique=True)
keyword_table = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.String(255), nullable=False,
server_default=db.text("'database'::character varying"))
@ -501,7 +502,7 @@ class Embedding(db.Model):
db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
model_name = db.Column(db.String(40), nullable=False,
server_default=db.text("'text-embedding-ada-002'::character varying"))
hash = db.Column(db.String(64), nullable=False)
@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model):
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)

View File

@ -7,13 +7,13 @@ from typing import Optional
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy import Float, text
from sqlalchemy.dialects.postgresql import UUID
from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from . import StringUUID
from .account import Account, Tenant
@ -56,15 +56,15 @@ class App(db.Model):
db.Index('app_tenant_id_idx', 'tenant_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
mode = db.Column(db.String(255), nullable=False)
icon = db.Column(db.String(255))
icon_background = db.Column(db.String(255))
app_model_config_id = db.Column(UUID, nullable=True)
workflow_id = db.Column(UUID, nullable=True)
app_model_config_id = db.Column(StringUUID, nullable=True)
workflow_id = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
enable_site = db.Column(db.Boolean, nullable=False)
enable_api = db.Column(db.Boolean, nullable=False)
@ -207,8 +207,8 @@ class AppModelConfig(db.Model):
db.Index('app_app_id_idx', 'app_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True)
configs = db.Column(db.JSON, nullable=True)
@ -430,8 +430,8 @@ class RecommendedApp(db.Model):
db.Index('recommended_app_is_listed_idx', 'is_listed', 'language')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
description = db.Column(db.JSON, nullable=False)
copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False)
@ -458,10 +458,10 @@ class InstalledApp(db.Model):
db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=False)
app_owner_tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
app_owner_tenant_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False, default=0)
is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
last_used_at = db.Column(db.DateTime, nullable=True)
@ -486,9 +486,9 @@ class Conversation(db.Model):
db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
app_model_config_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text)
model_id = db.Column(db.String(255), nullable=True)
@ -502,10 +502,10 @@ class Conversation(db.Model):
status = db.Column(db.String(255), nullable=False)
invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID)
from_account_id = db.Column(UUID)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
read_at = db.Column(db.DateTime)
read_account_id = db.Column(UUID)
read_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -626,12 +626,12 @@ class Message(db.Model):
db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
inputs = db.Column(db.JSON)
query = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False)
@ -650,12 +650,12 @@ class Message(db.Model):
message_metadata = db.Column(db.Text)
invoke_from = db.Column(db.String(255), nullable=True)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID)
from_account_id = db.Column(UUID)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
workflow_run_id = db.Column(UUID)
workflow_run_id = db.Column(StringUUID)
@property
def re_sign_file_url_answer(self) -> str:
@ -846,15 +846,15 @@ class MessageFeedback(db.Model):
db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, nullable=False)
message_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
message_id = db.Column(StringUUID, nullable=False)
rating = db.Column(db.String(255), nullable=False)
content = db.Column(db.Text)
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(UUID)
from_account_id = db.Column(UUID)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -872,15 +872,15 @@ class MessageFile(db.Model):
db.Index('message_file_created_by_idx', 'created_by')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True)
belongs_to = db.Column(db.String(255), nullable=True)
upload_file_id = db.Column(UUID, nullable=True)
upload_file_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -893,14 +893,14 @@ class MessageAnnotation(db.Model):
db.Index('message_annotation_message_idx', 'message_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
message_id = db.Column(StringUUID, nullable=True)
question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
account_id = db.Column(UUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -925,15 +925,15 @@ class AppAnnotationHitHistory(db.Model):
db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
annotation_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(UUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
score = db.Column(Float, nullable=False, server_default=db.text('0'))
message_id = db.Column(UUID, nullable=False)
message_id = db.Column(StringUUID, nullable=False)
annotation_question = db.Column(db.Text, nullable=False)
annotation_content = db.Column(db.Text, nullable=False)
@ -957,13 +957,13 @@ class AppAnnotationSetting(db.Model):
db.Index('app_annotation_settings_app_idx', 'app_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
collection_binding_id = db.Column(UUID, nullable=False)
created_user_id = db.Column(UUID, nullable=False)
collection_binding_id = db.Column(StringUUID, nullable=False)
created_user_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_user_id = db.Column(UUID, nullable=False)
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
@ -995,9 +995,9 @@ class OperationLog(db.Model):
db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
account_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
action = db.Column(db.String(255), nullable=False)
content = db.Column(db.JSON)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -1013,9 +1013,9 @@ class EndUser(UserMixin, db.Model):
db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(255), nullable=False)
external_user_id = db.Column(db.String(255), nullable=True)
name = db.Column(db.String(255))
@ -1033,8 +1033,8 @@ class Site(db.Model):
db.Index('site_code_idx', 'code', 'status')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
title = db.Column(db.String(255), nullable=False)
icon = db.Column(db.String(255))
icon_background = db.Column(db.String(255))
@ -1074,9 +1074,9 @@ class ApiToken(db.Model):
db.Index('api_token_tenant_idx', 'tenant_id', 'type')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=True)
tenant_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=True)
tenant_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(16), nullable=False)
token = db.Column(db.String(255), nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True)
@ -1099,8 +1099,8 @@ class UploadFile(db.Model):
db.Index('upload_file_tenant_idx', 'tenant_id')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
storage_type = db.Column(db.String(255), nullable=False)
key = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
@ -1108,10 +1108,10 @@ class UploadFile(db.Model):
extension = db.Column(db.String(255), nullable=False)
mime_type = db.Column(db.String(255), nullable=True)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
used_by = db.Column(UUID, nullable=True)
used_by = db.Column(StringUUID, nullable=True)
used_at = db.Column(db.DateTime, nullable=True)
hash = db.Column(db.String(255), nullable=True)
@ -1123,9 +1123,9 @@ class ApiRequest(db.Model):
db.Index('api_request_token_idx', 'tenant_id', 'api_token_id')
)
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
api_token_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
api_token_id = db.Column(StringUUID, nullable=False)
path = db.Column(db.String(255), nullable=False)
request = db.Column(db.Text, nullable=True)
response = db.Column(db.Text, nullable=True)
@ -1140,8 +1140,8 @@ class MessageChain(db.Model):
db.Index('message_chain_message_id_idx', 'message_id')
)
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
input = db.Column(db.Text, nullable=True)
output = db.Column(db.Text, nullable=True)
@ -1156,9 +1156,9 @@ class MessageAgentThought(db.Model):
db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'),
)
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
message_chain_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False)
message_chain_id = db.Column(StringUUID, nullable=True)
position = db.Column(db.Integer, nullable=False)
thought = db.Column(db.Text, nullable=True)
tool = db.Column(db.Text, nullable=True)
@ -1166,7 +1166,7 @@ class MessageAgentThought(db.Model):
tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = db.Column(db.Text, nullable=True)
observation = db.Column(db.Text, nullable=True)
# plugin_id = db.Column(UUID, nullable=True) ## for future design
# plugin_id = db.Column(StringUUID, nullable=True) ## for future design
tool_process_data = db.Column(db.Text, nullable=True)
message = db.Column(db.Text, nullable=True)
message_token = db.Column(db.Integer, nullable=True)
@ -1182,7 +1182,7 @@ class MessageAgentThought(db.Model):
currency = db.Column(db.String, nullable=True)
latency = db.Column(db.Float, nullable=True)
created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
@ -1273,15 +1273,15 @@ class DatasetRetrieverResource(db.Model):
db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
)
id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
dataset_id = db.Column(UUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
dataset_name = db.Column(db.Text, nullable=False)
document_id = db.Column(UUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
document_name = db.Column(db.Text, nullable=False)
data_source_type = db.Column(db.Text, nullable=False)
segment_id = db.Column(UUID, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
score = db.Column(db.Float, nullable=True)
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=True)
@ -1289,7 +1289,7 @@ class DatasetRetrieverResource(db.Model):
segment_position = db.Column(db.Integer, nullable=True)
index_node_hash = db.Column(db.Text, nullable=True)
retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@ -1303,11 +1303,11 @@ class Tag(db.Model):
TAG_TYPE_LIST = ['knowledge', 'app']
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(16), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -1319,9 +1319,9 @@ class TagBinding(db.Model):
db.Index('tag_bind_tag_id_idx', 'tag_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
tag_id = db.Column(UUID, nullable=True)
target_id = db.Column(UUID, nullable=True)
created_by = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=True)
tag_id = db.Column(StringUUID, nullable=True)
target_id = db.Column(StringUUID, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -1,8 +1,7 @@
from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
from models import StringUUID
class ProviderType(Enum):
@ -46,8 +45,8 @@ class Provider(db.Model):
db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True)
@ -93,8 +92,8 @@ class ProviderModel(db.Model):
db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
@ -111,8 +110,8 @@ class TenantDefaultModel(db.Model):
db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
@ -127,8 +126,8 @@ class TenantPreferredModelProvider(db.Model):
db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -142,10 +141,10 @@ class ProviderOrder(db.Model):
db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
account_id = db.Column(UUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191))
transaction_id = db.Column(db.String(191))

View File

@ -1,6 +1,7 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from models import StringUUID
class DataSourceBinding(db.Model):
@ -11,8 +12,8 @@ class DataSourceBinding(db.Model):
db.Index('source_info_idx', "source_info", postgresql_using='gin')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
access_token = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False)
source_info = db.Column(JSONB, nullable=False)

View File

@ -1,9 +1,8 @@
import json
from enum import Enum
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
from models import StringUUID
class ToolProviderName(Enum):
@ -24,8 +23,8 @@ class ToolProvider(db.Model):
db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
tool_name = db.Column(db.String(40), nullable=False)
encrypted_credentials = db.Column(db.Text, nullable=True)
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

View File

@ -1,12 +1,12 @@
import json
from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant
@ -22,11 +22,11 @@ class BuiltinToolProvider(db.Model):
)
# id of the tool provider
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the tenant
tenant_id = db.Column(UUID, nullable=True)
tenant_id = db.Column(StringUUID, nullable=True)
# who created this tool provider
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# name of the tool provider
provider = db.Column(db.String(40), nullable=False)
# credential of the tool provider
@ -49,11 +49,11 @@ class PublishedAppTool(db.Model):
)
# id of the tool provider
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# id of the app
app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False)
app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
# who published this tool
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
@ -87,7 +87,7 @@ class ApiToolProvider(db.Model):
db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# name of the api provider
name = db.Column(db.String(40), nullable=False)
# icon
@ -96,9 +96,9 @@ class ApiToolProvider(db.Model):
schema = db.Column(db.Text, nullable=False)
schema_type_str = db.Column(db.String(40), nullable=False)
# who created this tool
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
# description of the provider
description = db.Column(db.Text, nullable=False)
# json format tools
@ -140,11 +140,11 @@ class ToolModelInvoke(db.Model):
db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# who invoke this tool
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
# provider
provider = db.Column(db.String(40), nullable=False)
# type
@ -180,13 +180,13 @@ class ToolConversationVariables(db.Model):
db.Index('conversation_id_idx', 'conversation_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
# conversation id
conversation_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
# variables pool
variables_str = db.Column(db.Text, nullable=False)
@ -208,13 +208,13 @@ class ToolFile(db.Model):
db.Index('tool_file_conversation_id_idx', 'conversation_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
# conversation user id
user_id = db.Column(UUID, nullable=False)
user_id = db.Column(StringUUID, nullable=False)
# tenant id
tenant_id = db.Column(UUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
# conversation id
conversation_id = db.Column(UUID, nullable=True)
conversation_id = db.Column(StringUUID, nullable=True)
# file key
file_key = db.Column(db.String(255), nullable=False)
# mime type

View File

@ -1,6 +1,6 @@
from sqlalchemy.dialects.postgresql import UUID
from extensions.ext_database import db
from models import StringUUID
from models.model import Message
@ -11,11 +11,11 @@ class SavedMessage(db.Model):
db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
message_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
message_id = db.Column(StringUUID, nullable=False)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property
@ -30,9 +30,9 @@ class PinnedConversation(db.Model):
db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(UUID, nullable=False)
conversation_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -2,10 +2,9 @@ import json
from enum import Enum
from typing import Optional, Union
from sqlalchemy.dialects.postgresql import UUID
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models import StringUUID
from models.account import Account
@ -102,16 +101,16 @@ class Workflow(db.Model):
db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text)
features = db.Column(db.Text)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID)
updated_by = db.Column(StringUUID)
updated_at = db.Column(db.DateTime)
@property
@ -245,11 +244,11 @@ class WorkflowRun(db.Model):
db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
sequence_number = db.Column(db.Integer, nullable=False)
workflow_id = db.Column(UUID, nullable=False)
workflow_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
triggered_from = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False)
@ -262,7 +261,7 @@ class WorkflowRun(db.Model):
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
total_steps = db.Column(db.Integer, server_default=db.text('0'))
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
finished_at = db.Column(db.DateTime)
@ -404,12 +403,12 @@ class WorkflowNodeExecution(db.Model):
'triggered_from', 'node_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=False)
workflow_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
workflow_id = db.Column(StringUUID, nullable=False)
triggered_from = db.Column(db.String(255), nullable=False)
workflow_run_id = db.Column(UUID)
workflow_run_id = db.Column(StringUUID)
index = db.Column(db.Integer, nullable=False)
predecessor_node_id = db.Column(db.String(255))
node_id = db.Column(db.String(255), nullable=False)
@ -424,7 +423,7 @@ class WorkflowNodeExecution(db.Model):
execution_metadata = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime)
@property
@ -529,14 +528,14 @@ class WorkflowAppLog(db.Model):
db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
app_id = db.Column(UUID, nullable=False)
workflow_id = db.Column(UUID, nullable=False)
workflow_run_id = db.Column(UUID, nullable=False)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
workflow_id = db.Column(StringUUID, nullable=False)
workflow_run_id = db.Column(StringUUID, nullable=False)
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@property

View File

@ -1,7 +1,7 @@
beautifulsoup4==4.12.2
flask~=3.0.1
Flask-SQLAlchemy~=3.0.5
SQLAlchemy~=1.4.28
SQLAlchemy~=2.0.29
Flask-Compress~=1.14
flask-login~=0.6.3
flask-migrate~=4.0.5
@ -44,6 +44,7 @@ google-auth-httplib2==0.2.0
google-generativeai==0.5.0
google-search-results==2.4.2
googleapis-common-protos==1.63.0
google-cloud-storage==2.16.0
replicate~=0.22.0
websocket-client~=1.7.0
dashscope[tokenizer]~=1.17.0
@ -80,4 +81,5 @@ lxml==5.1.0
xlrd~=2.0.1
pydantic~=1.10.0
pgvecto-rs==0.1.4
oss2==2.15.0
firecrawl-py==0.0.5
oss2==2.15.0
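The bump from SQLAlchemy ~=1.4.28 to ~=2.0.29 is the dependency-side half of the StringUUID work above. SQLAlchemy 2.0 favors the select() API over the legacy Query interface; a minimal sketch of the 2.0 style, assuming the Account model and db session from the application code above and an illustrative email value:

from sqlalchemy import select

stmt = select(Account).where(Account.email == 'user@example.com')
account = db.session.execute(stmt).scalar_one_or_none()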

View File

@ -74,5 +74,5 @@ class BillingService:
TenantAccountJoin.account_id == current_user.id
).first()
if TenantAccountRole.is_privileged_role(join.role):
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError('Only team owner or team admin can perform this action')
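The one-character fix above inverts a backwards guard: the old code raised 'Only team owner or team admin can perform this action' precisely when the member was privileged and let everyone else through. Negating is_privileged_role() restores the intended permission check.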

View File

@ -418,9 +418,8 @@ class DocumentService:
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.indexing_status == 'error' or Document.indexing_status == 'paused'
Document.indexing_status.in_(['error', 'paused'])
).all()
return documents
@staticmethod
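The replaced filter combined two SQLAlchemy clause objects with Python's `or`, which is evaluated eagerly in Python and returns one of the two operands, so only a single condition ever reached the generated SQL and documents in the other state were silently skipped. Building the predicate at the SQL level fixes this; sqlalchemy.or_ would work equally well:

from sqlalchemy import or_

# Both of these render a real SQL predicate, unlike the Python-level `or`:
Document.indexing_status.in_(['error', 'paused'])
or_(Document.indexing_status == 'error', Document.indexing_status == 'paused')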

View File

@ -1,38 +1,36 @@
import uuid
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
get_sample_document,
get_sample_embedding,
get_sample_query_vector,
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
def test_milvus_vector(setup_mock_redis) -> None:
dataset_id = str(uuid.uuid4())
vector = MilvusVector(
collection_name=Dataset.gen_collection_name_by_id(dataset_id),
config=MilvusConfig(
host='localhost',
port=19530,
user='root',
password='Milvus',
class MilvusVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = MilvusVector(
collection_name=self.collection_name,
config=MilvusConfig(
host='localhost',
port=19530,
user='root',
password='Milvus',
)
)
)
# create vector
vector.create(
texts=[get_sample_document(dataset_id)],
embeddings=[get_sample_embedding()],
)
def search_by_full_text(self):
# milvus does not support full-text search yet in < 2.3.x

hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
# search by vector
hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
assert len(hits_by_vector) >= 1
def delete_by_document_id(self):
self.vector.delete_by_document_id(document_id=self.example_doc_id)
# milvus does not support full-text search yet in < 2.3.x
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
assert len(ids) == 1
# delete vector
vector.delete()
def test_milvus_vector(setup_mock_redis):
MilvusVectorTest().run_all_tests()

View File

@ -0,0 +1,37 @@
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class TestPgvectoRSVector(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = PGVectoRS(
collection_name=self.collection_name.lower(),
config=PgvectoRSConfig(
host='localhost',
port=5431,
user='postgres',
password='difyai123456',
database='dify',
),
dim=128
)
def search_by_full_text(self):
# pgvecto-rs only supports English full-text search, so it is not enabled for now
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def delete_by_document_id(self):
self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
assert len(ids) == 1
def test_pgvecot_rs(setup_mock_redis):
TestPgvectoRSVector().run_all_tests()

View File

@ -1,40 +1,23 @@
import uuid
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
get_sample_document,
get_sample_embedding,
get_sample_query_vector,
get_sample_text,
AbstractVectorTest,
setup_mock_redis,
)
def test_qdrant_vector(setup_mock_redis)-> None:
dataset_id = str(uuid.uuid4())
vector = QdrantVector(
collection_name=Dataset.gen_collection_name_by_id(dataset_id),
group_id=dataset_id,
config=QdrantConfig(
endpoint='http://localhost:6333',
api_key='difyai123456',
class QdrantVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.vector = QdrantVector(
collection_name=self.collection_name,
group_id=self.dataset_id,
config=QdrantConfig(
endpoint='http://localhost:6333',
api_key='difyai123456',
)
)
)
# create vector
vector.create(
texts=[get_sample_document(dataset_id)],
embeddings=[get_sample_embedding()],
)
# search by vector
hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
assert len(hits_by_vector) >= 1
# search by full text
hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
assert len(hits_by_full_text) >= 1
# delete vector
vector.delete()
def test_qdrant_vector(setup_mock_redis):
QdrantVectorTest().run_all_tests()

View File

@ -1,31 +1,26 @@
import random
import uuid
from unittest.mock import MagicMock
import pytest
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset
def get_sample_text() -> str:
def get_example_text() -> str:
return 'test_text'
def get_sample_embedding() -> list[float]:
return [1.1, 2.2, 3.3]
def get_sample_query_vector() -> list[float]:
return get_sample_embedding()
def get_sample_document(sample_dataset_id: str) -> Document:
def get_example_document(doc_id: str) -> Document:
doc = Document(
page_content=get_sample_text(),
page_content=get_example_text(),
metadata={
"doc_id": sample_dataset_id,
"doc_hash": sample_dataset_id,
"document_id": sample_dataset_id,
"dataset_id": sample_dataset_id,
"doc_id": doc_id,
"doc_hash": doc_id,
"document_id": doc_id,
"dataset_id": doc_id,
}
)
return doc
@ -44,3 +39,63 @@ def setup_mock_redis() -> None:
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractVectorTest:
def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)]
def create_vector(self) -> None:
self.vector.create(
texts=[get_example_document(doc_id=self.example_doc_id)],
embeddings=[self.example_embedding],
)
def search_by_vector(self):
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 1
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
def delete_vector(self):
self.vector.delete()
def delete_by_ids(self, ids: list[str]):
self.vector.delete_by_ids(ids=ids)
def add_texts(self) -> list[str]:
batch_size = 100
documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
embeddings = [self.example_embedding] * batch_size
self.vector.add_texts(documents=documents, embeddings=embeddings)
return [doc.metadata['doc_id'] for doc in documents]
def text_exists(self):
assert self.vector.text_exists(self.example_doc_id)
def delete_by_document_id(self):
with pytest.raises(NotImplementedError):
self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self):
with pytest.raises(NotImplementedError):
self.vector.get_ids_by_metadata_field(key='key', value='value')
def run_all_tests(self):
self.create_vector()
self.search_by_vector()
self.search_by_full_text()
self.text_exists()
self.get_ids_by_metadata_field()
self.delete_by_document_id()
added_doc_ids = self.add_texts()
self.delete_by_ids(added_doc_ids)
self.delete_vector()
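run_all_tests() makes AbstractVectorTest a template method: the base class drives the default lifecycle (create, search by vector and by full text, existence check, metadata lookup, document-level delete, bulk add and delete, drop), and each backend subclass overrides only the hooks where its capabilities differ, as the Milvus, pgvecto-rs, and Qdrant tests above and the Weaviate test below demonstrate.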

View File

@ -1,41 +1,23 @@
import uuid
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
get_sample_document,
get_sample_embedding,
get_sample_query_vector,
get_sample_text,
AbstractVectorTest,
setup_mock_redis,
)
def test_weaviate_vector(setup_mock_redis) -> None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
dataset_id = str(uuid.uuid4())
vector = WeaviateVector(
collection_name=Dataset.gen_collection_name_by_id(dataset_id),
config=WeaviateConfig(
endpoint='http://localhost:8080',
api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
),
attributes=attributes
)
class WeaviateVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.vector = WeaviateVector(
collection_name=self.collection_name,
config=WeaviateConfig(
endpoint='http://localhost:8080',
api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
),
attributes=self.attributes
)
# create vector
vector.create(
texts=[get_sample_document(dataset_id)],
embeddings=[get_sample_embedding()],
)
# search by vector
hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
assert len(hits_by_vector) >= 1
# search by full text
hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
assert len(hits_by_full_text) >= 1
# delete vector
vector.delete()
def test_weaviate_vector(setup_mock_redis):
WeaviateVectorTest().run_all_tests()

View File

@ -114,7 +114,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
result = node.run(pool)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == 'result in output form must be a string'
assert result.error == 'Output variable `result` must be a string'
def test_execute_code_output_validator_depth():
code = '''

View File

@ -53,7 +53,7 @@ services:
# The DifySandbox
sandbox:
image: langgenius/dify-sandbox:latest
image: langgenius/dify-sandbox:0.1.0
restart: always
cap_add:
# Why is sys_admin permission needed?
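Pinning the sandbox image to 0.1.0 rather than latest keeps the API and sandbox versions in lockstep across pulls, so a newly published sandbox release cannot silently change code-execution behavior.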
@ -80,3 +80,4 @@ services:
# QDRANT_API_KEY: 'difyai123456'
# ports:
# - "6333:6333"
# - "6334:6334"

View File

@ -0,0 +1,24 @@
version: '3'
services:
# The pgvecto-rs database.
pgvecto-rs:
image: tensorchord/pgvecto-rs:pg16-v0.2.0
restart: always
environment:
PGUSER: postgres
# The password for the default postgres user.
POSTGRES_PASSWORD: difyai123456
# The name of the default postgres database.
POSTGRES_DB: dify
# postgres data directory
PGDATA: /var/lib/postgresql/data/pgdata
volumes:
- ./volumes/pgvectors/data:/var/lib/postgresql/data
# uncomment to expose db(postgresql) port to host
ports:
- "5431:5432"
healthcheck:
test: [ "CMD", "pg_isready" ]
interval: 1s
timeout: 3s
retries: 30
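Before pointing Dify at the new service, a quick connectivity check through the host-mapped port; a sketch assuming psycopg2 is installed on the host, using the credentials from the compose file above:

import psycopg2

# Host port 5431 maps to the container's 5432 (see ports above).
conn = psycopg2.connect(host='localhost', port=5431, user='postgres',
                        password='difyai123456', dbname='dify')
with conn.cursor() as cur:
    cur.execute('SELECT 1')
    assert cur.fetchone() == (1,)
conn.close()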

View File

@ -10,3 +10,4 @@ services:
QDRANT_API_KEY: 'difyai123456'
ports:
- "6333:6333"
- "6334:6334"

View File

@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.6.5
image: langgenius/dify-api:0.6.6
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -70,7 +70,7 @@ services:
# If you want to enable cross-origin support,
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
#
# The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
# The type of storage to use for storing user files. Supported values are `local`, `s3`, `azure-blob` and `google-storage`. Default: `local`
STORAGE_TYPE: local
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
# only available when STORAGE_TYPE is `local`.
@ -86,7 +86,10 @@ services:
AZURE_BLOB_ACCOUNT_KEY: 'difyai'
AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
# The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
GOOGLE_STORAGE_BUCKET_NAME: 'yout-bucket-name'
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
VECTOR_STORE: weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT: http://weaviate:8080
@ -96,8 +99,12 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# The Qdrant clinet timeout setting.
# The Qdrant client timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# The Qdrant client enable gRPC mode.
QDRANT_GRPC_ENABLED: 'false'
# The Qdrant server gRPC mode PORT.
QDRANT_GRPC_PORT: 6334
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@ -109,6 +116,12 @@ services:
MILVUS_PASSWORD: Milvus
# The milvus tls switch.
MILVUS_SECURE: 'false'
# relyt configurations
RELYT_HOST: db
RELYT_PORT: 5432
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
# Mail configuration, support: resend, smtp
MAIL_TYPE: ''
# default send from email address, if not specified
@ -127,6 +140,11 @@ services:
SENTRY_TRACES_SAMPLE_RATE: 1.0
# The sample rate for Sentry profiles. Default: `1.0`
SENTRY_PROFILES_SAMPLE_RATE: 1.0
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE: public
NOTION_CLIENT_SECRET: you-client-secret
NOTION_CLIENT_ID: you-client-id
NOTION_INTERNAL_SECRET: you-internal-secret
# The sandbox service endpoint.
CODE_EXECUTION_ENDPOINT: "http://sandbox:8194"
CODE_EXECUTION_API_KEY: dify-sandbox
@ -150,7 +168,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.6.5
image: langgenius/dify-api:0.6.6
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@ -193,7 +211,7 @@ services:
AZURE_BLOB_ACCOUNT_KEY: 'difyai'
AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
# The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
VECTOR_STORE: weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT: http://weaviate:8080
@ -205,6 +223,10 @@ services:
QDRANT_API_KEY: difyai123456
# The Qdrant clinet timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# The Qdrant client enable gRPC mode.
QDRANT_GRPC_ENABLED: 'false'
# The Qdrant server gRPC mode PORT.
QDRANT_GRPC_PORT: 6334
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@ -229,6 +251,11 @@ services:
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
# Notion import configuration, support public and internal
NOTION_INTEGRATION_TYPE: public
NOTION_CLIENT_SECRET: you-client-secret
NOTION_CLIENT_ID: you-client-id
NOTION_INTERNAL_SECRET: you-internal-secret
depends_on:
- db
- redis
@ -238,10 +265,9 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.6.5
image: langgenius/dify-web:0.6.6
restart: always
environment:
EDITION: SELF_HOSTED
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
# example: http://cloud.dify.ai
@ -320,7 +346,7 @@ services:
# The DifySandbox
sandbox:
image: langgenius/dify-sandbox:latest
image: langgenius/dify-sandbox:0.1.0
restart: always
cap_add:
# Why is sys_admin permission needed?
@ -346,6 +372,7 @@ services:
# # uncomment to expose qdrant port to host
# # ports:
# # - "6333:6333"
# # - "6334:6334"
# The nginx reverse proxy.
# used for reverse proxying the API service and Web service.
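Of the new settings above, `GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64` expects the downloaded service-account key file encoded as a single base64 string, judging by the variable name and placeholder. One way to produce it; `service-account.json` is a placeholder path:

```python
# Encode a Google service-account key for the
# GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 setting shown above.
import base64

with open('service-account.json', 'rb') as f:
    print(base64.b64encode(f.read()).decode('ascii'))
```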

View File

@@ -1,6 +1,6 @@
 # For production release, change this to PRODUCTION
 NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
-# The deployment edition, SELF_HOSTED or CLOUD
+# The deployment edition, SELF_HOSTED
 NEXT_PUBLIC_EDITION=SELF_HOSTED
 # The base URL of console application, refers to the Console base URL of WEB service if console domain is
 # different from api or web app domain.

View File

@@ -17,7 +17,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t
 ```
 # For production release, change this to PRODUCTION
 NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
-# The deployment edition, SELF_HOSTED or CLOUD
+# The deployment edition, SELF_HOSTED
 NEXT_PUBLIC_EDITION=SELF_HOSTED
 # The base URL of console application, refers to the Console base URL of WEB service if console domain is
 # different from api or web app domain.

View File

@@ -1,11 +1,12 @@
 'use client'
-import { useEffect, useRef, useState } from 'react'
+import { useCallback, useEffect, useRef, useState } from 'react'
 import useSWRInfinite from 'swr/infinite'
 import { useTranslation } from 'react-i18next'
 import { useDebounceFn } from 'ahooks'
 import AppCard from './AppCard'
 import NewAppCard from './NewAppCard'
+import useAppsQueryState from './hooks/useAppsQueryState'
 import type { AppListResponse } from '@/models/app'
 import { fetchAppList } from '@/service/apps'
 import { useAppContext } from '@/context/app-context'
@@ -54,10 +55,15 @@ const Apps = () => {
   const [activeTab, setActiveTab] = useTabSearchParams({
     defaultTab: 'all',
   })
-  const [tagFilterValue, setTagFilterValue] = useState<string[]>([])
-  const [tagIDs, setTagIDs] = useState<string[]>([])
-  const [keywords, setKeywords] = useState('')
-  const [searchKeywords, setSearchKeywords] = useState('')
+  const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
+  const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
+  const [searchKeywords, setSearchKeywords] = useState(keywords)
+  const setKeywords = useCallback((keywords: string) => {
+    setQuery(prev => ({ ...prev, keywords }))
+  }, [setQuery])
+  const setTagIDs = useCallback((tagIDs: string[]) => {
+    setQuery(prev => ({ ...prev, tagIDs }))
+  }, [setQuery])
   const { data, isLoading, setSize, mutate } = useSWRInfinite(
     (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),
@@ -81,17 +87,18 @@ const Apps = () => {
     }
   }, [])
+  const hasMore = data?.at(-1)?.has_more ?? true
   useEffect(() => {
     let observer: IntersectionObserver | undefined
     if (anchorRef.current) {
       observer = new IntersectionObserver((entries) => {
-        if (entries[0].isIntersecting && !isLoading)
+        if (entries[0].isIntersecting && !isLoading && hasMore)
           setSize((size: number) => size + 1)
       }, { rootMargin: '100px' })
       observer.observe(anchorRef.current)
     }
     return () => observer?.disconnect()
-  }, [isLoading, setSize, anchorRef, mutate])
+  }, [isLoading, setSize, anchorRef, mutate, hasMore])
   const { run: handleSearch } = useDebounceFn(() => {
     setSearchKeywords(keywords)

View File

@@ -0,0 +1,53 @@
+import { type ReadonlyURLSearchParams, usePathname, useRouter, useSearchParams } from 'next/navigation'
+import { useCallback, useEffect, useMemo, useState } from 'react'
+
+type AppsQuery = {
+  tagIDs?: string[]
+  keywords?: string
+}
+
+// Parse the query parameters from the URL search string.
+function parseParams(params: ReadonlyURLSearchParams): AppsQuery {
+  const tagIDs = params.get('tagIDs')?.split(';')
+  const keywords = params.get('keywords') || undefined
+  return { tagIDs, keywords }
+}
+
+// Update the URL search string with the given query parameters.
+function updateSearchParams(query: AppsQuery, current: URLSearchParams) {
+  const { tagIDs, keywords } = query || {}
+  if (tagIDs && tagIDs.length > 0)
+    current.set('tagIDs', tagIDs.join(';'))
+  else
+    current.delete('tagIDs')
+  if (keywords)
+    current.set('keywords', keywords)
+  else
+    current.delete('keywords')
+}
+
+function useAppsQueryState() {
+  const searchParams = useSearchParams()
+  const [query, setQuery] = useState<AppsQuery>(() => parseParams(searchParams))
+
+  const router = useRouter()
+  const pathname = usePathname()
+  const syncSearchParams = useCallback((params: URLSearchParams) => {
+    const search = params.toString()
+    const query = search ? `?${search}` : ''
+    router.push(`${pathname}${query}`)
+  }, [router, pathname])
+
+  // Update the URL search string whenever the query changes.
+  useEffect(() => {
+    const params = new URLSearchParams(searchParams)
+    updateSearchParams(query, params)
+    syncSearchParams(params)
+  }, [query, searchParams, syncSearchParams])
+
+  return useMemo(() => ({ query, setQuery }), [query])
+}
+
+export default useAppsQueryState

View File

@@ -1,248 +1,248 @@
'use client'
import type { FC, SVGProps } from 'react'
import React, { useEffect } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import classNames from 'classnames'
import { useBoolean } from 'ahooks'
import {
Cog8ToothIcon,
// CommandLineIcon,
Squares2X2Icon,
// eslint-disable-next-line sort-imports
PuzzlePieceIcon,
DocumentTextIcon,
PaperClipIcon,
QuestionMarkCircleIcon,
} from '@heroicons/react/24/outline'
import {
Cog8ToothIcon as Cog8ToothSolidIcon,
// CommandLineIcon as CommandLineSolidIcon,
DocumentTextIcon as DocumentTextSolidIcon,
} from '@heroicons/react/24/solid'
import Link from 'next/link'
import s from './style.module.css'
import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets'
import type { RelatedApp, RelatedAppResponse } from '@/models/datasets'
-import { getLocaleOnClient } from '@/i18n'
import AppSideBar from '@/app/components/app-sidebar'
import Divider from '@/app/components/base/divider'
import AppIcon from '@/app/components/base/app-icon'
import Loading from '@/app/components/base/loading'
import FloatPopoverContainer from '@/app/components/base/float-popover-container'
import DatasetDetailContext from '@/context/dataset-detail'
import { DataSourceType } from '@/models/datasets'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { LanguagesSupported } from '@/i18n/language'
import { useStore } from '@/app/components/app/store'
import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication'
import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel'
+import { getLocaleOnClient } from '@/i18n'
export type IAppDetailLayoutProps = {
children: React.ReactNode
params: { datasetId: string }
}
type ILikedItemProps = {
type?: 'plugin' | 'app'
appStatus?: boolean
detail: RelatedApp
isMobile: boolean
}
const LikedItem = ({
type = 'app',
detail,
isMobile,
}: ILikedItemProps) => {
return (
<Link className={classNames(s.itemWrapper, 'px-2', isMobile && 'justify-center')} href={`/app/${detail?.id}/overview`}>
<div className={classNames(s.iconWrapper, 'mr-0')}>
<AppIcon size='tiny' icon={detail?.icon} background={detail?.icon_background} />
{type === 'app' && (
<span className='absolute bottom-[-2px] right-[-2px] w-3.5 h-3.5 p-0.5 bg-white rounded border-[0.5px] border-[rgba(0,0,0,0.02)] shadow-sm'>
{detail.mode === 'advanced-chat' && (
<ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
)}
{detail.mode === 'agent-chat' && (
<CuteRobote className='w-2.5 h-2.5 text-indigo-600' />
)}
{detail.mode === 'chat' && (
<ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
)}
{detail.mode === 'completion' && (
<AiText className='w-2.5 h-2.5 text-[#0E9384]' />
)}
{detail.mode === 'workflow' && (
<Route className='w-2.5 h-2.5 text-[#f79009]' />
)}
</span>
)}
</div>
{!isMobile && <div className={classNames(s.appInfo, 'ml-2')}>{detail?.name || '--'}</div>}
</Link>
)
}
const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_4610_6951)">
<path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
</g>
<defs>
<clipPath id="clip0_4610_6951">
<rect width="16" height="16" fill="white" />
</clipPath>
</defs>
</svg>
}
const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
<path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
<path d="M5.33317 7.99984C5.33317 6.52708 6.52708 5.33317 7.99984 5.33317C8.36803 5.33317 8.6665 5.03469 8.6665 4.6665C8.6665 4.29831 8.36803 3.99984 7.99984 3.99984C5.7907 3.99984 3.99984 5.7907 3.99984 7.99984C3.99984 10.209 5.7907 11.9998 7.99984 11.9998C10.209 11.9998 11.9998 10.209 11.9998 7.99984C11.9998 7.63165 11.7014 7.33317 11.3332 7.33317C10.965 7.33317 10.6665 7.63165 10.6665 7.99984C10.6665 9.4726 9.4726 10.6665 7.99984 10.6665C6.52708 10.6665 5.33317 9.4726 5.33317 7.99984Z" fill="#155EEF" />
</svg>
}
const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
<path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />
</svg>
}
type IExtraInfoProps = {
isMobile: boolean
relatedApps?: RelatedAppResponse
}
const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => {
const locale = getLocaleOnClient()
const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile)
const { t } = useTranslation()
useEffect(() => {
setShowTips(!isMobile)
}, [isMobile, setShowTips])
return <div className='w-full flex flex-col items-center'>
<Divider className='mt-5' />
{(relatedApps?.data && relatedApps?.data?.length > 0) && (
<>
{!isMobile && <div className='w-full px-2 pb-1 pt-4 uppercase text-xs text-gray-500 font-medium'>{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}</div>}
{isMobile && <div className={classNames(s.subTitle, 'flex items-center justify-center !px-0 gap-1')}>
{relatedApps?.total || '--'}
<PaperClipIcon className='h-4 w-4 text-gray-700' />
</div>}
{relatedApps?.data?.map((item, index) => (<LikedItem key={index} isMobile={isMobile} detail={item} />))}
</>
)}
{!relatedApps?.data?.length && (
<FloatPopoverContainer
placement='bottom-start'
open={isShowTips}
toggle={toggleTips}
isMobile={isMobile}
triggerElement={
<div className={classNames('h-7 w-7 inline-flex justify-center items-center rounded-lg bg-transparent', isShowTips && '!bg-gray-50')}>
<QuestionMarkCircleIcon className='h-4 w-4 flex-shrink-0 text-gray-500' />
</div>
}
>
<div className={classNames('mt-5 p-3', isMobile && 'border-[0.5px] border-gray-200 shadow-lg rounded-lg bg-white w-[160px]')}>
<div className='flex items-center justify-start gap-2'>
<div className={s.emptyIconDiv}>
<Squares2X2Icon className='w-3 h-3 text-gray-500' />
</div>
<div className={s.emptyIconDiv}>
<PuzzlePieceIcon className='w-3 h-3 text-gray-500' />
</div>
</div>
<div className='text-xs text-gray-500 mt-2'>{t('common.datasetMenus.emptyTip')}</div>
<a
className='inline-flex items-center text-xs text-primary-600 mt-2 cursor-pointer'
href={
locale === LanguagesSupported[1]
? 'https://docs.dify.ai/v/zh-hans/guides/application-design/prompt-engineering'
: 'https://docs.dify.ai/user-guide/creating-dify-apps/prompt-engineering'
}
target='_blank' rel='noopener noreferrer'
>
<BookOpenIcon className='mr-1' />
{t('common.datasetMenus.viewDoc')}
</a>
</div>
</FloatPopoverContainer>
)}
</div>
}
const DatasetDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
const {
children,
params: { datasetId },
} = props
const pathname = usePathname()
const hideSideBar = /documents\/create$/.test(pathname)
const { t } = useTranslation()
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const { data: datasetRes, error, mutate: mutateDatasetRes } = useSWR({
url: 'fetchDatasetDetail',
datasetId,
}, apiParams => fetchDatasetDetail(apiParams.datasetId))
const { data: relatedApps } = useSWR({
action: 'fetchDatasetRelatedApps',
datasetId,
}, apiParams => fetchDatasetRelatedApps(apiParams.datasetId))
const navigation = [
{ name: t('common.datasetMenus.documents'), href: `/datasets/${datasetId}/documents`, icon: DocumentTextIcon, selectedIcon: DocumentTextSolidIcon },
{ name: t('common.datasetMenus.hitTesting'), href: `/datasets/${datasetId}/hitTesting`, icon: TargetIcon, selectedIcon: TargetSolidIcon },
// { name: 'api & webhook', href: `/datasets/${datasetId}/api`, icon: CommandLineIcon, selectedIcon: CommandLineSolidIcon },
{ name: t('common.datasetMenus.settings'), href: `/datasets/${datasetId}/settings`, icon: Cog8ToothIcon, selectedIcon: Cog8ToothSolidIcon },
]
useEffect(() => {
if (datasetRes)
document.title = `${datasetRes.name || 'Dataset'} - Dify`
}, [datasetRes])
const setAppSiderbarExpand = useStore(state => state.setAppSiderbarExpand)
useEffect(() => {
const localeMode = localStorage.getItem('app-detail-collapse-or-expand') || 'expand'
const mode = isMobile ? 'collapse' : 'expand'
setAppSiderbarExpand(isMobile ? mode : localeMode)
}, [isMobile, setAppSiderbarExpand])
if (!datasetRes && !error)
return <Loading />
return (
<div className='grow flex overflow-hidden'>
{!hideSideBar && <AppSideBar
title={datasetRes?.name || '--'}
icon={datasetRes?.icon || 'https://static.dify.ai/images/dataset-default-icon.png'}
icon_background={datasetRes?.icon_background || '#F5F5F5'}
desc={datasetRes?.description || '--'}
navigation={navigation}
extraInfo={mode => <ExtraInfo isMobile={mode === 'collapse'} relatedApps={relatedApps} />}
iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'}
/>}
<DatasetDetailContext.Provider value={{
indexingTechnique: datasetRes?.indexing_technique,
dataset: datasetRes,
mutateDatasetRes: () => mutateDatasetRes(),
}}>
<div className="bg-white grow overflow-hidden">{children}</div>
</DatasetDetailContext.Provider>
</div>
)
}
export default React.memo(DatasetDetailLayout)

View File

@@ -86,7 +86,7 @@ const ActivateForm = () => {
         timezone,
       },
     })
-    setLocaleOnClient(language.startsWith('en') ? 'en' : 'zh-Hans', false)
+    setLocaleOnClient(language.startsWith('en') ? 'en-US' : 'zh-Hans', false)
     setShowSuccess(true)
   }
   catch {

View File

@@ -362,7 +362,7 @@ const Answer: FC<IAnswerProps> = ({
   {!item.isOpeningStatement && (
     <CopyBtn
       value={content}
-      className={cn(s.copyBtn, 'mr-1')}
+      className='mr-1'
     />
   )}
   {((isShowPromptLog && !isResponding) || (!item.isOpeningStatement && isShowTextToSpeech)) && (

View File

@@ -8,9 +8,8 @@ import { useParams } from 'next/navigation'
 import { HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
 import { useBoolean } from 'ahooks'
 import { HashtagIcon } from '@heroicons/react/24/solid'
-// import PromptLog from '@/app/components/app/chat/log'
+import ResultTab from './result-tab'
 import { Markdown } from '@/app/components/base/markdown'
-import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
 import Loading from '@/app/components/base/loading'
 import Toast from '@/app/components/base/toast'
 import AudioBtn from '@/app/components/base/audio-btn'
@@ -26,7 +25,6 @@ import EditReplyModal from '@/app/components/app/annotation/edit-annotation-moda
 import { useStore as useAppStore } from '@/app/components/app/store'
 import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process'
 import type { WorkflowProcess } from '@/app/components/base/chat/types'
-import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
 const MAX_DEPTH = 3
@@ -293,23 +291,17 @@ const GenerationItem: FC<IGenerationItemProps> = ({
       <div className={`flex ${contentClassName}`}>
         <div className='grow w-0'>
           {workflowProcessData && (
-            <WorkflowProcessItem grayBg data={workflowProcessData} expand={workflowProcessData.expand} />
+            <WorkflowProcessItem grayBg hideInfo data={workflowProcessData} expand={workflowProcessData.expand} />
           )}
+          {workflowProcessData && !isError && (
+            <ResultTab data={workflowProcessData} content={content} />
+          )}
           {isError && (
             <div className='text-gray-400 text-sm'>{t('share.generation.batchFailed.outputPlaceholder')}</div>
           )}
-          {!isError && (typeof content === 'string') && (
+          {!workflowProcessData && !isError && (typeof content === 'string') && (
             <Markdown content={content} />
           )}
-          {!isError && (typeof content !== 'string') && (
-            <CodeEditor
-              readOnly
-              title={<div/>}
-              language={CodeLanguage.json}
-              value={content}
-              isJSONStringifyBeauty
-            />
-          )}
         </div>
       </div>
@@ -427,7 +419,11 @@ const GenerationItem: FC<IGenerationItemProps> = ({
           </>
         )}
       </div>
-      <div className='text-xs text-gray-500'>{content?.length} {t('common.unit.char')}</div>
+      <div>
+        {!workflowProcessData && (
+          <div className='text-xs text-gray-500'>{content?.length} {t('common.unit.char')}</div>
+        )}
+      </div>
     </div>
   </div>

Some files were not shown because too many files have changed in this diff.