Mirror of https://github.com/langgenius/dify.git
Synced 2026-02-05 19:25:32 +08:00

Compare commits (63 commits)
| SHA1 |
|---|
| 93393e005e |
| 4ea2755fce |
| ecb51a83d4 |
| 093b5c0e63 |
| bf42b0ae44 |
| 342b4fd19d |
| cbdb861ee4 |
| da5a8b9a59 |
| 1e6e8b446d |
| c1fdaa6ae0 |
| 142814d451 |
| 704755d005 |
| d1263700c0 |
| 0704fe9695 |
| 1d3f1d88ef |
| 8b3edac091 |
| 05cab85579 |
| b72fbe200d |
| b1194da6a5 |
| 338e4669e5 |
| c5e2659771 |
| 1d432728ac |
| 2fd702a319 |
| f26ad16af7 |
| 8f2ae51fe5 |
| 2f84d00300 |
| b82a2d97ef |
| 3e9dbe3e0a |
| 975b2fb79e |
| fa509ce64e |
| 99292edd46 |
| 3e992cb23c |
| e7b4d024ee |
| ff67a6d338 |
| 8e4989ed03 |
| 0940f01634 |
| 9d1cb1bc92 |
| 0ca4e30b19 |
| ba88f8a6f0 |
| aefe0cbf51 |
| 9ad489d133 |
| 661b30784e |
| 43a5ba9415 |
| 08a65d74d5 |
| cefe156811 |
| 3b5b4d628b |
| 8746e48df0 |
| 0ec8b57825 |
| 045827043d |
| 4d66a86579 |
| 2a8881d0e8 |
| ffc60bb917 |
| 2e454c770b |
| 7d711135bc |
| f62b2b5b45 |
| 7919596a21 |
| 9b4898efeb |
| 45dd1683fd |
| 8bca908f15 |
| 9cbb8ddd7f |
| 1be222af2e |
| bf9fc8fef4 |
| 86e7330fa2 |
```diff
@@ -32,8 +32,8 @@
     ]
   }
 },
-"postStartCommand": "cd api && pip install -r requirements.txt",
-"postCreateCommand": "cd web && npm install"
+"postStartCommand": "./.devcontainer/post_start_command.sh",
+"postCreateCommand": "./.devcontainer/post_create_command.sh"

 // Features to add to the dev container. More info: https://containers.dev/features.
 // "features": {},
```
.devcontainer/post_create_command.sh (new executable file, 10 lines)

```diff
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+cd web && npm install
+
+echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
+echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
+echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
+
+source /home/vscode/.bashrc
```
.devcontainer/post_start_command.sh (new executable file, 3 lines)

```diff
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+cd api && pip install -r requirements.txt
```
.github/workflows/api-tests.yml (vendored, 58 lines changed)

```diff
@@ -14,50 +14,10 @@ jobs:
           - "3.10"
           - "3.11"
 
-    env:
-      OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
-      AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
-      AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
-      ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
-      CHATGLM_API_BASE: http://a.abc.com:11451
-      XINFERENCE_SERVER_URL: http://a.abc.com:11451
-      XINFERENCE_GENERATION_MODEL_UID: generate
-      XINFERENCE_CHAT_MODEL_UID: chat
-      XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
-      XINFERENCE_RERANK_MODEL_UID: rerank
-      GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
-      HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
-      HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
-      HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
-      HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
-      MOCK_SWITCH: true
-      CODE_MAX_STRING_LENGTH: 80000
-
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
 
-      - name: Set up Weaviate
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.middleware.yaml
-          services: weaviate
-
-      - name: Set up Qdrant
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.qdrant.yaml
-          services: qdrant
-
-      - name: Set up Milvus
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.milvus.yaml
-          services: |
-            etcd
-            minio
-            milvus-standalone
-
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
```
```diff
@@ -82,5 +42,21 @@ jobs:
       - name: Run Workflow
         run: dev/pytest/pytest_workflow.sh
 
-      - name: Run Vector Stores
+      - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
+        uses: hoverkraft-tech/compose-action@v2.0.0
+        with:
+          compose-file: |
+            docker/docker-compose.middleware.yaml
+            docker/docker-compose.qdrant.yaml
+            docker/docker-compose.milvus.yaml
+            docker/docker-compose.pgvecto-rs.yaml
+          services: |
+            weaviate
+            qdrant
+            etcd
+            minio
+            milvus-standalone
+            pgvecto-rs
+
+      - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh
```
```diff
@@ -1,6 +1,3 @@
-# Server Edition
-EDITION=SELF_HOSTED
-
 # Your App secret key will be used for securely signing the session cookie
 # Make sure you are changing this key for your deployment with a strong key.
 # You can generate a strong key using `openssl rand -base64 42`.
```
```diff
@@ -57,12 +54,15 @@ ALIYUN_OSS_BUCKET_NAME=your-bucket-name
 ALIYUN_OSS_ACCESS_KEY=your-access-key
 ALIYUN_OSS_SECRET_KEY=your-secret-key
 ALIYUN_OSS_ENDPOINT=your-endpoint
+# Google Storage configuration
+GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
+GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-string
 
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, relyt
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
```
```diff
@@ -75,6 +75,8 @@ WEAVIATE_BATCH_SIZE=100
 QDRANT_URL=http://localhost:6333
 QDRANT_API_KEY=difyai123456
 QDRANT_CLIENT_TIMEOUT=20
+QDRANT_GRPC_ENABLED=false
+QDRANT_GRPC_PORT=6334
 
 # Milvus configuration
 MILVUS_HOST=127.0.0.1
```
```diff
@@ -90,6 +92,13 @@ RELYT_USER=postgres
 RELYT_PASSWORD=postgres
 RELYT_DATABASE=postgres
 
+# PGVECTO_RS configuration
+PGVECTO_RS_HOST=localhost
+PGVECTO_RS_PORT=5431
+PGVECTO_RS_USER=postgres
+PGVECTO_RS_PASSWORD=difyai123456
+PGVECTO_RS_DATABASE=postgres
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
```
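For orientation (not part of the diff): a minimal sketch of how these PGVECTO_RS_* settings end up composed into a SQLAlchemy connection URL, matching the `postgresql+psycopg2://` format used by the new PGVectoRS vector store later in this compare view. The env-reading snippet itself is illustrative:

```python
import os

# Illustrative only: compose the PGVECTO_RS_* settings above into the
# SQLAlchemy URL format the new PGVectoRS vector store builds internally.
host = os.getenv("PGVECTO_RS_HOST", "localhost")
port = os.getenv("PGVECTO_RS_PORT", "5431")
user = os.getenv("PGVECTO_RS_USER", "postgres")
password = os.getenv("PGVECTO_RS_PASSWORD", "difyai123456")
database = os.getenv("PGVECTO_RS_DATABASE", "postgres")

url = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
print(url)  # postgresql+psycopg2://postgres:difyai123456@localhost:5431/postgres
```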
```diff
@@ -123,25 +132,6 @@ NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret
 
-# Hosted Model Credentials
-HOSTED_OPENAI_API_KEY=
-HOSTED_OPENAI_API_BASE=
-HOSTED_OPENAI_API_ORGANIZATION=
-HOSTED_OPENAI_TRIAL_ENABLED=false
-HOSTED_OPENAI_QUOTA_LIMIT=200
-HOSTED_OPENAI_PAID_ENABLED=false
-
-HOSTED_AZURE_OPENAI_ENABLED=false
-HOSTED_AZURE_OPENAI_API_KEY=
-HOSTED_AZURE_OPENAI_API_BASE=
-HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
-
-HOSTED_ANTHROPIC_API_BASE=
-HOSTED_ANTHROPIC_API_KEY=
-HOSTED_ANTHROPIC_TRIAL_ENABLED=false
-HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
-HOSTED_ANTHROPIC_PAID_ENABLED=false
-
 ETL_TYPE=dify
 UNSTRUCTURED_API_URL=
 
```
```diff
@@ -166,5 +156,10 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
 API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
 API_TOOL_DEFAULT_READ_TIMEOUT=60
 
+# HTTP Node configuration
+HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
+HTTP_REQUEST_MAX_READ_TIMEOUT=600
+HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
+
 # Log file path
 LOG_FILE=
```
api/app.py (20 lines changed)

```diff
@@ -1,28 +1,28 @@
 import os
+import sys
+from logging.handlers import RotatingFileHandler
 
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
 
     monkey.patch_all()
     # if os.environ.get("VECTOR_STORE") == 'milvus':
     import grpc.experimental.gevent
 
     grpc.experimental.gevent.init_gevent()
 
 import json
 import logging
-import sys
 import threading
 import time
 import warnings
-from logging.handlers import RotatingFileHandler
 
 from flask import Flask, Response, request
 from flask_cors import CORS
 from werkzeug.exceptions import Unauthorized
 
 from commands import register_commands
-from config import CloudEditionConfig, Config
+from config import Config
 
 # DO NOT REMOVE BELOW
 from events import event_handlers
```
```diff
@@ -75,16 +75,9 @@ config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
 # ----------------------------
 
 
-def create_app(test_config=None) -> Flask:
+def create_app() -> Flask:
     app = DifyApp(__name__)
 
-    if test_config:
-        app.config.from_object(test_config)
-    else:
-        if config_type == "CLOUD":
-            app.config.from_object(CloudEditionConfig())
-        else:
-            app.config.from_object(Config())
+    app.config.from_object(Config())
 
     app.secret_key = app.config['SECRET_KEY']
```
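A minimal sketch of the configuration pattern this change moves to: one Config class for both editions, with EDITION resolved from the environment instead of a CloudEditionConfig subclass. The trimmed-down helper and class here are illustrative; the real lookup lives in api/config.py (see the hunks below):

```python
import os

DEFAULTS = {'EDITION': 'SELF_HOSTED'}

def get_env(key: str) -> str:
    # Same lookup order as api/config.py: real environment first, then defaults.
    return os.environ.get(key, DEFAULTS.get(key, ''))

class Config:
    def __init__(self):
        # One class for both editions; setting EDITION=CLOUD in the environment
        # replaces the old CloudEditionConfig subclass removed further down.
        self.EDITION = get_env('EDITION')
```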
```diff
@@ -101,6 +94,7 @@ def create_app(test_config=None) -> Flask:
             ),
+            logging.StreamHandler(sys.stdout)
         ]
 
     logging.basicConfig(
         level=app.config.get('LOG_LEVEL'),
         format=app.config.get('LOG_FORMAT'),
```
```diff
@@ -5,6 +5,7 @@ import dotenv
 dotenv.load_dotenv()
 
 DEFAULTS = {
+    'EDITION': 'SELF_HOSTED',
     'DB_USERNAME': 'postgres',
     'DB_PASSWORD': '',
     'DB_HOST': 'localhost',
```
```diff
@@ -36,6 +37,8 @@ DEFAULTS = {
     'WEAVIATE_GRPC_ENABLED': 'True',
     'WEAVIATE_BATCH_SIZE': 100,
     'QDRANT_CLIENT_TIMEOUT': 20,
+    'QDRANT_GRPC_ENABLED': 'False',
+    'QDRANT_GRPC_PORT': '6334',
     'CELERY_BACKEND': 'database',
     'LOG_LEVEL': 'INFO',
     'LOG_FILE': '',
```
```diff
@@ -104,9 +107,9 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.6.5"
+        self.CURRENT_VERSION = "0.6.6"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
-        self.EDITION = "SELF_HOSTED"
+        self.EDITION = get_env('EDITION')
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
```
```diff
@@ -212,6 +215,8 @@ class Config:
         self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
         self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
         self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
+        self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
+        self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
 
         # ------------------------
         # Vector Store Configurations.
```
```diff
@@ -223,6 +228,8 @@ class Config:
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
         self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
+        self.QDRANT_GRPC_ENABLED = get_env('QDRANT_GRPC_ENABLED')
+        self.QDRANT_GRPC_PORT = get_env('QDRANT_GRPC_PORT')
 
         # milvus / zilliz setting
         self.MILVUS_HOST = get_env('MILVUS_HOST')
```
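For context, the two new gRPC settings are the kind of options the qdrant-client constructor accepts; a hedged sketch wiring the .env defaults from above (the parameter names come from the qdrant-client library, not from this diff):

```python
from qdrant_client import QdrantClient

# Illustrative wiring of the new settings: prefer_grpc switches the transport,
# grpc_port is only consulted when gRPC is enabled.
client = QdrantClient(
    url='http://localhost:6333',   # QDRANT_URL
    api_key='difyai123456',        # QDRANT_API_KEY
    timeout=20,                    # QDRANT_CLIENT_TIMEOUT
    prefer_grpc=False,             # QDRANT_GRPC_ENABLED
    grpc_port=6334,                # QDRANT_GRPC_PORT
)
```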
```diff
@@ -245,6 +252,13 @@ class Config:
         self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
         self.RELYT_DATABASE = get_env('RELYT_DATABASE')
 
+        # pgvecto rs settings
+        self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
+        self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
+        self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
+        self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
+        self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------
```
```diff
@@ -260,7 +274,7 @@ class Config:
         self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
 
         # ------------------------
-        # Workpace Configurations.
+        # Workspace Configurations.
         # ------------------------
         self.INVITE_EXPIRY_HOURS = int(get_env('INVITE_EXPIRY_HOURS'))
```
```diff
@@ -299,6 +313,12 @@ class Config:
         # ------------------------
         # Platform Configurations.
         # ------------------------
+        self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
+        self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
+        self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
+        self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
+        self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
+
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
```
```diff
@@ -345,17 +365,3 @@ class Config:
 
         self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
         self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
-
-
-class CloudEditionConfig(Config):
-
-    def __init__(self):
-        super().__init__()
-
-        self.EDITION = "CLOUD"
-
-        self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
-        self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
-        self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
-        self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
-        self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
```
```diff
@@ -1,10 +1,11 @@
 
 
-languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
+languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
 
 language_timezone_mapping = {
     'en-US': 'America/New_York',
     'zh-Hans': 'Asia/Shanghai',
+    'zh-Hant': 'Asia/Taipei',
     'pt-BR': 'America/Sao_Paulo',
     'es-ES': 'Europe/Madrid',
     'fr-FR': 'Europe/Paris',
```
```diff
@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
```
```diff
@@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
 
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
```
```diff
@@ -394,9 +394,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
     def get(self, dataset_id, batch):
         dataset_id = str(dataset_id)
         batch = str(batch)
-        dataset = DatasetService.get_dataset(dataset_id)
-        if dataset is None:
-            raise NotFound("Dataset not found.")
         documents = self.get_batch_documents(dataset_id, batch)
         response = {
             "tokens": 0,
```
```diff
@@ -28,9 +28,9 @@ from core.app.entities.task_entities import (
     AdvancedChatTaskState,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
+    ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
     MessageEndStreamResponse,
-    StreamGenerateRoute,
     StreamResponse,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
```
```diff
@@ -343,7 +343,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             **extras
         )
 
-    def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
+    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
         """
         Get stream generate routes.
         :return:
```
```diff
@@ -366,7 +366,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 continue
 
             for start_node_id in start_node_ids:
-                stream_generate_routes[start_node_id] = StreamGenerateRoute(
+                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
                     answer_node_id=answer_node_id,
                     generate_route=generate_route
                 )
```
```diff
@@ -430,15 +430,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             for route_chunk in route_chunks:
                 if route_chunk.type == 'text':
                     route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                    for token in route_chunk.text:
-                        # handle output moderation chunk
-                        should_direct_answer = self._handle_output_moderation_chunk(token)
-                        if should_direct_answer:
-                            continue
-
-                        self._task_state.answer += token
-                        yield self._message_to_stream_response(token, self._message.id)
-                        time.sleep(0.01)
+                    # handle output moderation chunk
+                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
+                    if should_direct_answer:
+                        continue
+
+                    self._task_state.answer += route_chunk.text
+                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
                 else:
                     break
```
```diff
@@ -463,10 +462,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         for route_chunk in route_chunks:
             if route_chunk.type == 'text':
                 route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                for token in route_chunk.text:
-                    self._task_state.answer += token
-                    yield self._message_to_stream_response(token, self._message.id)
-                    time.sleep(0.01)
+                self._task_state.answer += route_chunk.text
+                yield self._message_to_stream_response(route_chunk.text, self._message.id)
             else:
                 route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                 value_selector = route_chunk.value_selector
```
```diff
@@ -28,11 +28,13 @@ from core.app.entities.task_entities import (
     WorkflowAppBlockingResponse,
     WorkflowAppStreamResponse,
     WorkflowFinishStreamResponse,
+    WorkflowStreamGenerateNodes,
     WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
-from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.node_entities import NodeType, SystemVariable
+from core.workflow.nodes.end.end_node import EndNode
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser
```
```diff
@@ -40,6 +42,7 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
+    WorkflowNodeExecution,
     WorkflowRun,
 )
```
```diff
@@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         }
 
         self._task_state = WorkflowTaskState()
+        self._stream_generate_nodes = self._get_stream_generate_nodes()
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
```
```diff
@@ -167,6 +171,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueNodeStartedEvent):
                 workflow_node_execution = self._handle_node_start(event)
+
+                # search stream_generate_routes if node id is answer start at node
+                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
+                    self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
+
+                    # generate stream outputs when node started
+                    yield from self._generate_stream_outputs_when_node_started()
+
                 yield self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
```
```diff
@@ -174,6 +186,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 workflow_node_execution = self._handle_node_finished(event)
+
                 yield self._workflow_node_finish_to_stream_response(
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
```
```diff
@@ -193,6 +206,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if delta_text is None:
                     continue
 
+                if not self._is_stream_out_support(
+                        event=event
+                ):
+                    continue
+
                 self._task_state.answer += delta_text
                 yield self._text_chunk_to_stream_response(delta_text)
             elif isinstance(event, QueueMessageReplaceEvent):
```
```diff
@@ -254,3 +272,142 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             task_id=self._application_generate_entity.task_id,
             text=TextReplaceStreamResponse.Data(text=text)
         )
+
+    def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
+        """
+        Get stream generate nodes.
+        :return:
+        """
+        # find all answer nodes
+        graph = self._workflow.graph_dict
+        end_node_configs = [
+            node for node in graph['nodes']
+            if node.get('data', {}).get('type') == NodeType.END.value
+        ]
+
+        # parse stream output node value selectors of end nodes
+        stream_generate_routes = {}
+        for node_config in end_node_configs:
+            # get generate route for stream output
+            end_node_id = node_config['id']
+            generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
+            start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
+            if not start_node_ids:
+                continue
+
+            for start_node_id in start_node_ids:
+                stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
+                    end_node_id=end_node_id,
+                    stream_node_ids=generate_nodes
+                )
+
+        return stream_generate_routes
+
+    def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
+            -> list[str]:
+        """
+        Get end start at node id.
+        :param graph: graph
+        :param target_node_id: target node ID
+        :return:
+        """
+        nodes = graph.get('nodes')
+        edges = graph.get('edges')
+
+        # fetch all ingoing edges from source node
+        ingoing_edges = []
+        for edge in edges:
+            if edge.get('target') == target_node_id:
+                ingoing_edges.append(edge)
+
+        if not ingoing_edges:
+            return []
+
+        start_node_ids = []
+        for ingoing_edge in ingoing_edges:
+            source_node_id = ingoing_edge.get('source')
+            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
+            if not source_node:
+                continue
+
+            node_type = source_node.get('data', {}).get('type')
+            if node_type in [
+                NodeType.IF_ELSE.value,
+                NodeType.QUESTION_CLASSIFIER.value
+            ]:
+                start_node_id = target_node_id
+                start_node_ids.append(start_node_id)
+            elif node_type == NodeType.START.value:
+                start_node_id = source_node_id
+                start_node_ids.append(start_node_id)
+            else:
+                sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
+                if sub_start_node_ids:
+                    start_node_ids.extend(sub_start_node_ids)
+
+        return start_node_ids
+
+    def _generate_stream_outputs_when_node_started(self) -> Generator:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if self._task_state.current_stream_generate_state:
+            stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
+
+            for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
+                if node_id not in stream_node_ids:
+                    continue
+
+                node_execution_info = self._task_state.ran_node_execution_infos[node_id]
+
+                # get chunk node execution
+                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
+
+                if not route_chunk_node_execution:
+                    continue
+
+                outputs = route_chunk_node_execution.outputs_dict
+
+                if not outputs:
+                    continue
+
+                # get value from outputs
+                text = outputs.get('text')
+
+                if text:
+                    self._task_state.answer += text
+                    yield self._text_chunk_to_stream_response(text)
+
+            db.session.close()
+
+    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.metadata:
+            return False
+
+        if 'node_id' not in event.metadata:
+            return False
+
+        node_id = event.metadata.get('node_id')
+        node_type = event.metadata.get('node_type')
+        stream_output_value_selector = event.metadata.get('value_selector')
+        if not stream_output_value_selector:
+            return False
+
+        if not self._task_state.current_stream_generate_state:
+            return False
+
+        if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
+            return False
+
+        if node_type != NodeType.LLM:
+            # only LLM support chunk stream output
+            return False
+
+        return True
```
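The recursive `_get_end_start_at_node_ids` walk above is easier to follow on a toy graph; a standalone sketch with the traversal logic lifted out of the class (plain strings stand in for the NodeType enum values, and the graph itself is hypothetical):

```python
# Toy graph in the same shape the method expects: a branch node (if-else,
# question-classifier) stops the walk at the target, a start node stops it
# at itself, anything else recurses one edge upstream.
graph = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'llm', 'data': {'type': 'llm'}},
        {'id': 'end', 'data': {'type': 'end'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'llm'},
        {'source': 'llm', 'target': 'end'},
    ],
}

def end_start_at_node_ids(graph: dict, target_node_id: str) -> list[str]:
    nodes, edges = graph['nodes'], graph['edges']
    ingoing = [e for e in edges if e.get('target') == target_node_id]
    start_node_ids = []
    for edge in ingoing:
        source = next((n for n in nodes if n['id'] == edge['source']), None)
        if not source:
            continue
        node_type = source['data']['type']
        if node_type in ('if-else', 'question-classifier'):
            start_node_ids.append(target_node_id)  # stream starts at the end node itself
        elif node_type == 'start':
            start_node_ids.append(edge['source'])
        else:
            start_node_ids.extend(end_start_at_node_ids(graph, edge['source']))
    return start_node_ids

print(end_start_at_node_ids(graph, 'end'))  # ['start']
```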
```diff
@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
```
```diff
@@ -119,7 +120,15 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
         """
         Publish text chunk
         """
-        pass
+        self._queue_manager.publish(
+            QueueTextChunkEvent(
+                text=text,
+                metadata={
+                    "node_id": node_id,
+                    **metadata
+                }
+            ), PublishFrom.APPLICATION_MANAGER
+        )
 
     def on_event(self, event: AppQueueEvent) -> None:
         """
```
```diff
@@ -9,9 +9,17 @@ from core.workflow.entities.node_entities import NodeType
 from core.workflow.nodes.answer.entities import GenerateRouteChunk
 
 
-class StreamGenerateRoute(BaseModel):
+class WorkflowStreamGenerateNodes(BaseModel):
     """
-    StreamGenerateRoute entity
+    WorkflowStreamGenerateNodes entity
+    """
+    end_node_id: str
+    stream_node_ids: list[str]
+
+
+class ChatflowStreamGenerateRoute(BaseModel):
+    """
+    ChatflowStreamGenerateRoute entity
     """
     answer_node_id: str
     generate_route: list[GenerateRouteChunk]
```
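A tiny instantiation sketch of the split above, with simplified stand-in models (the real entities also pull in GenerateRouteChunk); field values here are hypothetical:

```python
from pydantic import BaseModel

# Simplified stand-ins for the two entities above.
class WorkflowStreamGenerateNodes(BaseModel):
    end_node_id: str
    stream_node_ids: list[str]

class ChatflowStreamGenerateRoute(BaseModel):
    answer_node_id: str
    generate_route: list  # list[GenerateRouteChunk] in the real entity

wf = WorkflowStreamGenerateNodes(end_node_id='end', stream_node_ids=['llm'])
cf = ChatflowStreamGenerateRoute(answer_node_id='answer', generate_route=[])
```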
```diff
@@ -55,6 +63,8 @@ class WorkflowTaskState(TaskState):
     ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
     latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
+    current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
+
 
 class AdvancedChatTaskState(WorkflowTaskState):
     """
```
```diff
@@ -62,7 +72,7 @@ class AdvancedChatTaskState(WorkflowTaskState):
     """
     usage: LLMUsage
 
-    current_stream_generate_state: Optional[StreamGenerateRoute] = None
+    current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
 
 
 class StreamEvent(Enum):
```
```diff
@@ -6,7 +6,7 @@ from yarl import URL
 
 from config import get_env
 from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
-from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
+from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer
 
 # Code Executor
```
```diff
@@ -55,6 +55,7 @@ if __name__ == '__main__':
 
     """
 
+
 class Jinja2TemplateTransformer(TemplateTransformer):
     @classmethod
     def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
```
```diff
@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
 - mistral.mistral-large-2402-v1:0
```
```diff
@@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :return:md = genai.GenerativeModel(model)
         """
         prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
 
-        if isinstance(messages, str):
-            prompt = messages
-        else:
-            prompt = self._convert_messages_to_prompt(messages, prefix)
+        prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
-    def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Google model
-
-        :param messages: List of PromptMessage to combine.
-        :return: Combined string with necessary human_prompt and ai_prompt tags.
-        """
-        messages = messages.copy()  # don't mutate the original list
-
-        text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
-            for message in messages
-        )
-
-        return text.rstrip()
-
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
```
```diff
@@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Convert a single message to a string.
 
```
```diff
@@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             ai_prompt = "\n\nAssistant:"
 
         elif model_prefix == "meta":
-            human_prompt_prefix = "\n[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = ""
+            # LLAMA3
+            if model_name.startswith("llama3"):
+                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                ai_prompt = "\n\nAssistant:"
+            else:
+                # LLAMA2
+                human_prompt_prefix = "\n[INST]"
+                human_prompt_postfix = "[\\INST]\n"
+                ai_prompt = ""
 
         elif model_prefix == "mistral":
             human_prompt_prefix = "<s>[INST]"
             human_prompt_postfix = "[\\INST]\n"
```
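A standalone sketch of what the llama2 vs llama3 branch above produces for a short exchange; the prefix/postfix strings are copied from the diff, while the `render` helper and tuple-based messages are simplifications of the real PromptMessage handling:

```python
def render(messages: list[tuple[str, str]], llama3: bool) -> str:
    # Strings copied from the branch above: llama3 uses header tokens,
    # llama2 keeps the [INST] wrapping.
    if llama3:
        prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    else:
        prefix = "\n[INST]"
        postfix = "[\\INST]\n"
    out = ""
    for role, content in messages:
        if role == "user":
            out += f"{prefix}{content}{postfix}"
        else:
            out += content
    return out.rstrip()

print(render([("user", "Hi"), ("assistant", "Hello!")], llama3=True))
```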
```diff
@@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
 
         :param messages: List of PromptMessage to combine.
+        :param model_name: specific model name.Optional,just to distinguish llama2 and llama3
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
         if not messages:
```
```diff
@@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             messages.append(AssistantPromptMessage(content=""))
 
         text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
+            self._convert_one_message_to_text(message, model_prefix, model_name)
             for message in messages
         )
 
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
 
-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+    def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = dict()
+        model_prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
 
         if model_prefix == "amazon":
             payload["textGenerationConfig"] = { **model_parameters }
```
```diff
@@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         elif model_prefix == "meta":
             payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
 
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
```
```diff
@@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
 
         model_prefix = model.split('.')[0]
-        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+        payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
 
         # need workaround for ai21 models which doesn't support streaming
         if stream and model_prefix != "ai21":
```
```diff
@@ -0,0 +1,23 @@
+model: meta.llama3-70b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 70B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00265'
+  output: '0.0035'
+  unit: '0.00001'
+  currency: USD
```
```diff
@@ -0,0 +1,23 @@
+model: meta.llama3-8b-instruct-v1:0
+label:
+  en_US: Llama 3 Instruct 8B
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.0004'
+  output: '0.0006'
+  unit: '0.0001'
+  currency: USD
```
```diff
@@ -0,0 +1,37 @@
+model: abab6.5-chat
+label:
+  en_US: Abab6.5-Chat
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0.01
+    max: 1
+    default: 0.1
+  - name: top_p
+    use_template: top_p
+    min: 0.01
+    max: 1
+    default: 0.95
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 8192
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+pricing:
+  input: '0.03'
+  output: '0.03'
+  unit: '0.001'
+  currency: RMB
```
```diff
@@ -0,0 +1,37 @@
+model: abab6.5s-chat
+label:
+  en_US: Abab6.5s-Chat
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 245760
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0.01
+    max: 1
+    default: 0.1
+  - name: top_p
+    use_template: top_p
+    min: 0.01
+    max: 1
+    default: 0.95
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 245760
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+pricing:
+  input: '0.01'
+  output: '0.01'
+  unit: '0.001'
+  currency: RMB
```
```diff
@@ -1,7 +1,7 @@
 - google/gemma-7b
 - google/codegemma-7b
 - meta/llama2-70b
-- meta/llama3-8b
-- meta/llama3-70b
+- meta/llama3-8b-instruct
+- meta/llama3-70b-instruct
 - mistralai/mixtral-8x7b-instruct-v0.1
 - fuyu-8b
```
```diff
@@ -1,7 +1,7 @@
-model: meta/llama3-70b
+model: meta/llama3-70b-instruct
 label:
-  zh_Hans: meta/llama3-70b
-  en_US: meta/llama3-70b
+  zh_Hans: meta/llama3-70b-instruct
+  en_US: meta/llama3-70b-instruct
 model_type: llm
 features:
   - agent-thought
```
```diff
@@ -1,7 +1,7 @@
-model: meta/llama3-8b
+model: meta/llama3-8b-instruct
 label:
-  zh_Hans: meta/llama3-8b
-  en_US: meta/llama3-8b
+  zh_Hans: meta/llama3-8b-instruct
+  en_US: meta/llama3-8b-instruct
 model_type: llm
 features:
   - agent-thought
```
```diff
@@ -26,8 +26,8 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
         'google/gemma-7b': '',
         'google/codegemma-7b': '',
         'meta/llama2-70b': '',
-        'meta/llama3-8b': '',
-        'meta/llama3-70b': ''
+        'meta/llama3-8b-instruct': '',
+        'meta/llama3-70b-instruct': ''
     }
```
```diff
@@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
 
-        version = credentials['model_version']
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         model_info = client.models.get(model)
-        model_info_version = model_info.versions.get(version)
+
+        if model_version:
+            model_info_version = model_info.versions.get(model_version)
+        else:
+            model_info_version = model_info.latest_version
 
         inputs = {**model_parameters}
```
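The version-resolution pattern above in isolation: use an explicit version when configured, otherwise fall back to the model's latest published version. A hedged sketch (the `resolve_version` helper is illustrative, but the client calls mirror the diff):

```python
from replicate.client import Client as ReplicateClient

def resolve_version(credentials: dict, model: str):
    # An empty or missing model_version now means "use the latest version".
    model_version = credentials.get('model_version', '')
    client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
    model_info = client.models.get(model)
    return model_info.versions.get(model_version) if model_version else model_info.latest_version
```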
```diff
@@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         if 'replicate_api_token' not in credentials:
             raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
 
-        if 'model_version' not in credentials:
-            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         if model.count("/") != 1:
             raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
                                                  'format: {user_name}/{model_name}')
 
-        version = credentials['model_version']
-
         try:
             client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
             model_info = client.models.get(model)
-            model_info_version = model_info.versions.get(version)
 
-            self._check_text_generation_model(model_info_version, model, version)
+            if model_version:
+                model_info_version = model_info.versions.get(model_version)
+            else:
+                model_info_version = model_info.latest_version
+
+            self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
         except ReplicateError as e:
             raise CredentialsValidateFailedError(
-                f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
+                f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
         except Exception as e:
             raise CredentialsValidateFailedError(str(e))
 
     @staticmethod
-    def _check_text_generation_model(model_info_version, model_name, version):
+    def _check_text_generation_model(model_info_version, model_name, version, description):
+        if 'language model' in description.lower():
+            return
+
         if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
                 or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
                 or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
```
```diff
@@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
 
     @classmethod
     def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
-        version = credentials['model_version']
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         model_info = client.models.get(model)
-        model_info_version = model_info.versions.get(version)
+
+        if model_version:
+            model_info_version = model_info.versions.get(model_version)
+        else:
+            model_info_version = model_info.latest_version
 
         parameter_rules = []
```
```diff
@@ -35,7 +35,7 @@ model_credential_schema:
     label:
       en_US: Model Version
     type: text-input
-    required: true
+    required: false
     placeholder:
-      zh_Hans: 在此输入您的模型版本
-      en_US: Enter your model version
+      zh_Hans: 在此输入您的模型版本,默认为最新版本
+      en_US: Enter your model version, default to the latest version
```
```diff
@@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
                 user: Optional[str] = None) -> TextEmbeddingResult:
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
-        replicate_model_version = f'{model}:{credentials["model_version"]}'
 
-        text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
+        else:
+            model_info = client.models.get(model)
+            model_version = model_info.latest_version.id
+
+        replicate_model_version = f'{model}:{model_version}'
+
+        text_input_key = self._get_text_input_key(model, model_version, client)
 
         embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
                                                                  texts)
```
```diff
@@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
         if 'replicate_api_token' not in credentials:
             raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
 
-        if 'model_version' not in credentials:
-            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
-
         try:
             client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
-            replicate_model_version = f'{model}:{credentials["model_version"]}'
 
-            text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
+            if 'model_version' in credentials:
+                model_version = credentials['model_version']
+            else:
+                model_info = client.models.get(model)
+                model_version = model_info.latest_version.id
+
+            replicate_model_version = f'{model}:{model_version}'
+
+            text_input_key = self._get_text_input_key(model, model_version, client)
 
             self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
                                                         ['Hello worlds!'])
```
```diff
@@ -1,9 +1,23 @@
 from collections.abc import Generator
+from decimal import Decimal
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    DefaultParameterName,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+    PriceConfig,
+)
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
```
```diff
@@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+        REPETITION_PENALTY = "repetition_penalty"
+        TOP_K = "top_k"
+        features = []
 
-        return super().get_customizable_model_schema(model, cred_with_endpoint)
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")),
+                ModelPropertyKey.MODE: cred_with_endpoint.get('mode'),
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name=DefaultParameterName.TEMPERATURE.value,
+                    label=I18nObject(en_US="Temperature"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('temperature', 0.7)),
+                    min=0,
+                    max=2,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.TOP_P.value,
+                    label=I18nObject(en_US="Top P"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('top_p', 1)),
+                    min=0,
+                    max=1,
+                    precision=2
+                ),
+                ParameterRule(
+                    name=TOP_K,
+                    label=I18nObject(en_US="Top K"),
+                    type=ParameterType.INT,
+                    default=int(cred_with_endpoint.get('top_k', 50)),
+                    min=-2147483647,
+                    max=2147483647,
+                    precision=0
+                ),
+                ParameterRule(
+                    name=REPETITION_PENALTY,
+                    label=I18nObject(en_US="Repetition Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(cred_with_endpoint.get('repetition_penalty', 1)),
+                    min=-3.4,
+                    max=3.4,
+                    precision=1
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.MAX_TOKENS.value,
+                    label=I18nObject(en_US="Max Tokens"),
+                    type=ParameterType.INT,
+                    default=512,
+                    min=1,
+                    max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)),
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
+                    label=I18nObject(en_US="Frequency Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('frequency_penalty', 0)),
+                    min=-2,
+                    max=2
+                ),
+                ParameterRule(
+                    name=DefaultParameterName.PRESENCE_PENALTY.value,
+                    label=I18nObject(en_US="Presence Penalty"),
+                    type=ParameterType.FLOAT,
+                    default=float(credentials.get('presence_penalty', 0)),
+                    min=-2,
+                    max=2
+                )
+            ],
+            pricing=PriceConfig(
+                input=Decimal(cred_with_endpoint.get('input_price', 0)),
+                output=Decimal(cred_with_endpoint.get('output_price', 0)),
+                unit=Decimal(cred_with_endpoint.get('unit', 0)),
+                currency=cred_with_endpoint.get('currency', "USD")
+            ),
+        )
+
+        if cred_with_endpoint['mode'] == 'chat':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif cred_with_endpoint['mode'] == 'completion':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}")
+
+        return entity
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
```
```diff
@@ -32,3 +32,8 @@ parameter_rules:
       zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。
       en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return.
     required: false
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 8192
```
```diff
@@ -124,7 +124,7 @@ class MilvusVector(BaseVector):
         if ids:
             self._client.delete(collection_name=self._collection_name, pks=ids)
 
-    def delete_by_ids(self, doc_ids: list[str]) -> None:
+    def delete_by_ids(self, ids: list[str]) -> None:
         alias = uuid4().hex
         if self._client_config.secure:
             uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
```

```diff
@@ -136,7 +136,7 @@ class MilvusVector(BaseVector):
         if utility.has_collection(self._collection_name, using=alias):
 
             result = self._client.query(collection_name=self._collection_name,
-                                        filter=f'metadata["doc_id"] in {doc_ids}',
+                                        filter=f'metadata["doc_id"] in {ids}',
                                         output_fields=["id"])
             if result:
                 ids = [item["id"] for item in result]
```
api/core/rag/datasource/vdb/pgvecto_rs/__init__.py (new empty file)

api/core/rag/datasource/vdb/pgvecto_rs/collection.py (new file, 12 lines)

```diff
@@ -0,0 +1,12 @@
+from uuid import UUID
+
+from numpy import ndarray
+from sqlalchemy.orm import DeclarativeBase, Mapped
+
+
+class CollectionORM(DeclarativeBase):
+    __tablename__: str
+    id: Mapped[UUID]
+    text: Mapped[str]
+    meta: Mapped[dict]
+    vector: Mapped[ndarray]
```
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py (new file, 224 lines)

```diff
@@ -0,0 +1,224 @@
+import logging
+from typing import Any
+from uuid import UUID, uuid4
+
+from numpy import ndarray
+from pgvecto_rs.sqlalchemy import Vector
+from pydantic import BaseModel, root_validator
+from sqlalchemy import Float, String, create_engine, insert, select, text
+from sqlalchemy import text as sql_text
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.orm import Mapped, Session, mapped_column
+
+from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class PgvectoRSConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config PGVECTO_RS_HOST is required")
+        if not values['port']:
+            raise ValueError("config PGVECTO_RS_PORT is required")
+        if not values['user']:
+            raise ValueError("config PGVECTO_RS_USER is required")
+        if not values['password']:
+            raise ValueError("config PGVECTO_RS_PASSWORD is required")
+        if not values['database']:
+            raise ValueError("config PGVECTO_RS_DATABASE is required")
+        return values
+
+
+class PGVectoRS(BaseVector):
+
+    def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
+        self._client = create_engine(self._url)
+        with Session(self._client) as session:
+            session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
+            session.commit()
+        self._fields = []
+
+        class _Table(CollectionORM):
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}  # noqa: RUF012
+            id: Mapped[UUID] = mapped_column(
+                postgresql.UUID(as_uuid=True),
+                primary_key=True,
+            )
+            text: Mapped[str] = mapped_column(String)
+            meta: Mapped[dict] = mapped_column(postgresql.JSONB)
+            vector: Mapped[ndarray] = mapped_column(Vector(dim))
+
+        self._table = _Table
+        self._distance_op = "<=>"
+
+    def get_type(self) -> str:
+        return 'pgvecto-rs'
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self.create_collection(len(embeddings[0]))
+        self.add_texts(texts, embeddings)
+
+    def create_collection(self, dimension: int):
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            index_name = f"{self._collection_name}_embedding_index"
+            with Session(self._client) as session:
+                create_statement = sql_text(f"""
+                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
+                        id UUID PRIMARY KEY,
+                        text TEXT NOT NULL,
+                        meta JSONB NOT NULL,
+                        vector vector({dimension}) NOT NULL
+                    ) using heap;
+                """)
+                session.execute(create_statement)
+                index_statement = sql_text(f"""
+                    CREATE INDEX IF NOT EXISTS {index_name}
+                    ON {self._collection_name} USING vectors(vector vector_l2_ops)
+                    WITH (options = $$
+                    optimizing.optimizing_threads = 30
+                    segment.max_growing_segment_size = 2000
+                    segment.max_sealed_segment_size = 30000000
+                    [indexing.hnsw]
+                    m=30
+                    ef_construction=500
+                    $$);
+                """)
+                session.execute(index_statement)
+                session.commit()
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        pks = []
+        with Session(self._client) as session:
+            for document, embedding in zip(documents, embeddings):
+                pk = uuid4()
+                session.execute(
+                    insert(self._table).values(
+                        id=pk,
+                        text=document.page_content,
+                        meta=document.metadata,
+                        vector=embedding,
+                    ),
+                )
+                pks.append(pk)
+            session.commit()
+
+        return pks
+
+    def delete_by_document_id(self, document_id: str):
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            with Session(self._client) as session:
+                select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                session.execute(select_statement, {'ids': ids})
+                session.commit()
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        result = None
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+            )
+            result = session.execute(select_statement).fetchall()
+        if result:
+            return [item[0] for item in result]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            with Session(self._client) as session:
+                select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                session.execute(select_statement, {'ids': ids})
+                session.commit()
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
+            )
+            result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
+        if result:
+            ids = [item[0] for item in result]
+            if ids:
+                with Session(self._client) as session:
+                    select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
+                    session.execute(select_statement, {'ids': ids})
+                    session.commit()
+
+    def delete(self) -> None:
+        with Session(self._client) as session:
+            session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
+            session.commit()
+
+    def text_exists(self, id: str) -> bool:
+        with Session(self._client) as session:
+            select_statement = sql_text(
+                f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+            )
+            result = session.execute(select_statement).fetchall()
+        return len(result) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        with Session(self._client) as session:
+            stmt = (
+                select(
+                    self._table,
+                    self._table.vector.op(self._distance_op, return_type=Float)(
+                        query_vector,
+                    ).label("distance"),
+                )
+                .limit(kwargs.get('top_k', 2))
+                .order_by("distance")
+            )
+            res = session.execute(stmt)
+            results = [(row[0], row[1]) for row in res]
+
+        # Organize results.
+        docs = []
+        for record, dis in results:
+            metadata = record.meta
+            score = 1 - dis
+            metadata['score'] = score
+            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
+            if score > score_threshold:
+                doc = Document(page_content=record.text,
+                               metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # with Session(self._client) as session:
+        #     select_statement = sql_text(
+        #         f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
+        #     )
+        #     results = session.execute(select_statement).fetchall()
+        #     if results:
+        #         docs = []
+        #         for result in results:
```
|
||||
# doc = Document(page_content=result[0],
|
||||
# metadata=result[1])
|
||||
# docs.append(doc)
|
||||
# return docs
|
||||
return []
|
||||
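Note that `<=>` is pgvecto.rs's cosine-distance operator, so `search_by_vector` above converts each returned distance into a similarity score with `score = 1 - distance` before applying `score_threshold`. A minimal standalone sketch of that conversion (the row values are invented for illustration):

```python
# Sketch of the distance-to-score conversion used by search_by_vector above.
# (text, cosine distance) pairs; values are made up, not from a real query.
rows = [("chunk-a", 0.12), ("chunk-b", 0.55), ("chunk-c", 0.97)]

score_threshold = 0.5
for text, distance in rows:
    score = 1 - distance  # cosine distance -> similarity score
    if score > score_threshold:
        print(f"keep {text}: score={score:.2f}")  # only chunk-a survives
```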
@@ -36,6 +36,8 @@ class QdrantConfig(BaseModel):
    api_key: Optional[str]
    timeout: float = 20
    root_path: Optional[str]
    grpc_port: int = 6334
    prefer_grpc: bool = False

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
@@ -51,7 +53,9 @@ class QdrantConfig(BaseModel):
                'url': self.endpoint,
                'api_key': self.api_key,
                'timeout': self.timeout,
                'verify': self.endpoint.startswith('https')
                'verify': self.endpoint.startswith('https'),
                'grpc_port': self.grpc_port,
                'prefer_grpc': self.prefer_grpc
            }
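With the two new fields, `to_qdrant_params` can hand gRPC settings straight to the Qdrant client. A hedged sketch of how the returned dict might be consumed (assumes the `qdrant-client` package; the endpoint and key are placeholders, and whether `verify` is accepted as a passthrough kwarg depends on the qdrant-client version):

```python
from qdrant_client import QdrantClient

# Placeholder params mirroring the shape to_qdrant_params() returns above.
params = {
    'url': 'https://qdrant.example.com',  # placeholder endpoint
    'api_key': 'xxx',                     # placeholder key
    'timeout': 20,
    'verify': True,
    'grpc_port': 6334,
    'prefer_grpc': False,                 # flip to True to use gRPC transport
}
client = QdrantClient(**params)
```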
@@ -113,8 +117,7 @@ class QdrantVector(BaseVector):

        # create payload index
        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
                                          field_schema=PayloadSchemaType.KEYWORD,
                                          field_type=PayloadSchemaType.KEYWORD)
                                          field_schema=PayloadSchemaType.KEYWORD)
        # create full text index
        text_index_params = TextIndexParams(
            type=TextIndexType.TEXT,
@@ -1,16 +1,23 @@
import logging
from typing import Any
import uuid
from typing import Any, Optional

from pgvecto_rs.sdk import PGVectoRs, Record
from pydantic import BaseModel, root_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)
Base = declarative_base()  # type: Any


class RelytConfig(BaseModel):
    host: str
@@ -36,16 +43,14 @@ class RelytConfig(BaseModel):

class RelytVector(BaseVector):

    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
    def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
        super().__init__(collection_name)
        self.embedding_dimension = 1536
        self._client_config = config
        self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
        self._client = PGVectoRs(
            db_url=self._url,
            collection_name=self._collection_name,
            dimension=dim
        )
        self.client = create_engine(self._url)
        self._fields = []
        self._group_id = group_id

    def get_type(self) -> str:
        return 'relyt'
@@ -54,6 +59,7 @@ class RelytVector(BaseVector):
        index_params = {}
        metadatas = [d.metadata for d in texts]
        self.create_collection(len(embeddings[0]))
        self.embedding_dimension = len(embeddings[0])
        self.add_texts(texts, embeddings)

    def create_collection(self, dimension: int):
@@ -63,21 +69,21 @@ class RelytVector(BaseVector):
            if redis_client.get(collection_exist_cache_key):
                return
            index_name = f"{self._collection_name}_embedding_index"
            with Session(self._client._engine) as session:
                drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
            with Session(self.client) as session:
                drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
                        id UUID PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSONB NOT NULL,
                    CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
                        id TEXT PRIMARY KEY,
                        document TEXT NOT NULL,
                        metadata JSON NOT NULL,
                        embedding vector({dimension}) NOT NULL
                    ) using heap;
                """)
                session.execute(create_statement)
                index_statement = sql_text(f"""
                    CREATE INDEX {index_name}
                    ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
                    ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
                    WITH (options = $$
                    optimizing.optimizing_threads = 30
                    segment.max_growing_segment_size = 2000
@@ -92,21 +98,62 @@ class RelytVector(BaseVector):
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
        pks = [str(r.id) for r in records]
        self._client.insert(records)
        return pks
        from pgvecto_rs.sqlalchemy import Vector

        ids = [str(uuid.uuid1()) for _ in documents]
        metadatas = [d.metadata for d in documents]
        for metadata in metadatas:
            metadata['group_id'] = self._group_id
        texts = [d.page_content for d in documents]

        # Define the table schema
        chunks_table = Table(
            self._collection_name,
            Base.metadata,
            Column("id", TEXT, primary_key=True),
            Column("embedding", Vector(len(embeddings[0]))),
            Column("document", String, nullable=True),
            Column("metadata", JSON, nullable=True),
            extend_existing=True,
        )

        chunks_table_data = []
        with self.client.connect() as conn:
            with conn.begin():
                for document, metadata, chunk_id, embedding in zip(
                        texts, metadatas, ids, embeddings
                ):
                    chunks_table_data.append(
                        {
                            "id": chunk_id,
                            "embedding": embedding,
                            "document": document,
                            "metadata": metadata,
                        }
                    )

                    # Execute the batch insert when the batch size is reached
                    if len(chunks_table_data) == 500:
                        conn.execute(insert(chunks_table).values(chunks_table_data))
                        # Clear the chunks_table_data list for the next batch
                        chunks_table_data.clear()

                # Insert any remaining records that didn't make up a full batch
                if chunks_table_data:
                    conn.execute(insert(chunks_table).values(chunks_table_data))

        return ids

    def delete_by_document_id(self, document_id: str):
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._client.delete_by_ids(ids)
            self.delete_by_uuids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        result = None
        with Session(self._client._engine) as session:
        with Session(self.client) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
            )
            result = session.execute(select_statement).fetchall()
        if result:
@@ -114,56 +161,140 @@ class RelytVector(BaseVector):
        else:
            return None

    def delete_by_uuids(self, ids: list[str] = None):
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        from pgvecto_rs.sqlalchemy import Vector

        if ids is None:
            raise ValueError("No ids provided to delete.")

        # Define the table schema
        chunks_table = Table(
            self._collection_name,
            Base.metadata,
            Column("id", TEXT, primary_key=True),
            Column("embedding", Vector(self.embedding_dimension)),
            Column("document", String, nullable=True),
            Column("metadata", JSON, nullable=True),
            extend_existing=True,
        )

        try:
            with self.client.connect() as conn:
                with conn.begin():
                    delete_condition = chunks_table.c.id.in_(ids)
                    conn.execute(chunks_table.delete().where(delete_condition))
                    return True
        except Exception as e:
            print("Delete operation failed:", str(e))  # noqa: T201
            return False

    def delete_by_metadata_field(self, key: str, value: str):

        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._client.delete_by_ids(ids)
            self.delete_by_uuids(ids)

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        with Session(self._client._engine) as session:
    def delete_by_ids(self, ids: list[str]) -> None:

        with Session(self.client) as session:
            ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
            )
            result = session.execute(select_statement).fetchall()
        if result:
            ids = [item[0] for item in result]
            self._client.delete_by_ids(ids)
            self.delete_by_uuids(ids)

    def delete(self) -> None:
        with Session(self._client._engine) as session:
            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
        with Session(self.client) as session:
            session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
            session.commit()

    def text_exists(self, id: str) -> bool:
        with Session(self._client._engine) as session:
        with Session(self.client) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
            )
            result = session.execute(select_statement).fetchall()
        return len(result) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from pgvecto_rs.sdk import filters
        filter_condition = filters.meta_contains(kwargs.get('filter'))
        results = self._client.search(
            top_k=int(kwargs.get('top_k')),
        results = self.similarity_search_with_score_by_vector(
            k=int(kwargs.get('top_k')),
            embedding=query_vector,
            filter=filter_condition
            filter=kwargs.get('filter')
        )

        # Organize results.
        docs = []
        for record, dis in results:
            metadata = record.meta
            metadata['score'] = dis
        for document, score in results:
            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
            if dis > score_threshold:
                doc = Document(page_content=record.text,
                               metadata=metadata)
                docs.append(doc)
            if 1 - score > score_threshold:
                docs.append(document)
        return docs

    def similarity_search_with_score_by_vector(
            self,
            embedding: list[float],
            k: int = 4,
            filter: Optional[dict] = None,
    ) -> list[tuple[Document, float]]:
        # Add the filter if provided
        try:
            from sqlalchemy.engine import Row
        except ImportError:
            raise ImportError(
                "Could not import Row from sqlalchemy.engine. "
                "Please 'pip install sqlalchemy>=1.4'."
            )

        filter_condition = ""
        if filter is not None:
            conditions = [
                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
                else f"metadata->>{key!r} = {value[0]!r}"
                for key, value in filter.items()
            ]
            filter_condition = f"WHERE {' AND '.join(conditions)}"

        # Define the base query
        sql_query = f"""
            set vectors.enable_search_growing = on;
            set vectors.enable_search_write = on;
            SELECT document, metadata, embedding <-> :embedding as distance
            FROM "{self._collection_name}"
            {filter_condition}
            ORDER BY embedding <-> :embedding
            LIMIT :k
        """

        # Set up the query parameters
        embedding_str = ", ".join(format(x) for x in embedding)
        embedding_str = "[" + embedding_str + "]"
        params = {"embedding": embedding_str, "k": k}

        # Execute the query and fetch the results
        with self.client.connect() as conn:
            results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()

        documents_with_scores = [
            (
                Document(
                    page_content=result.document,
                    metadata=result.metadata,
                ),
                result.distance,
            )
            for result in results
        ]
        return documents_with_scores

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz/relyt doesn't support bm25 search
        return []
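For reference, the `filter` dict accepted by `similarity_search_with_score_by_vector` is flattened into a SQL `WHERE` clause by the list comprehension above. A quick illustration of the strings it produces (the keys and values here are hypothetical):

```python
# Reproduces the filter_condition construction above with sample input.
filter = {'group_id': ['dataset-123'], 'doc_id': ['a', 'b']}

conditions = [
    f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
    else f"metadata->>{key!r} = {value[0]!r}"
    for key, value in filter.items()
]
print(f"WHERE {' AND '.join(conditions)}")
# WHERE metadata->>'group_id' = 'dataset-123' AND metadata->>'doc_id' in ('a', 'b')
```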
@@ -27,6 +27,12 @@ class BaseVector(ABC):
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    def get_ids_by_metadata_field(self, key: str, value: str):
        raise NotImplementedError

    @abstractmethod
    def delete_by_metadata_field(self, key: str, value: str) -> None:
        raise NotImplementedError
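This hunk promotes `delete_by_metadata_field` to an abstract method, so every vector-store backend now has to provide it. A minimal sketch of what that obliges a subclass to supply (the `InMemoryVector` name and its list-based storage are invented; the other abstract methods of `BaseVector`, such as `create`, `add_texts`, and `search_by_vector`, are omitted for brevity):

```python
# Hypothetical backend illustrating the newly required method.
class InMemoryVector(BaseVector):
    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._rows = []  # list of (id, text, metadata) tuples

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        # Required by the new @abstractmethod above.
        self._rows = [r for r in self._rows if r[2].get(key) != value]

    def delete_by_ids(self, ids: list[str]) -> None:
        self._rows = [r for r in self._rows if r[0] not in ids]
```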
@@ -86,7 +86,9 @@ class Vector:
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path,
                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
                    timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
                    grpc_port=config.get('QDRANT_GRPC_PORT'),
                    prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
                )
            )
        elif vector_type == "milvus":
@@ -126,7 +128,6 @@ class Vector:
                "vector_store": {"class_prefix": collection_name}
            }
            self._dataset.index_struct = json.dumps(index_struct_dict)
            dim = len(self._embeddings.embed_query("hello relyt"))
            return RelytVector(
                collection_name=collection_name,
                config=RelytConfig(
@@ -136,6 +137,31 @@ class Vector:
                    password=config.get('RELYT_PASSWORD'),
                    database=config.get('RELYT_DATABASE'),
                ),
                group_id=self._dataset.id
            )
        elif vector_type == "pgvecto_rs":
            from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
            if self._dataset.index_struct_dict:
                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
                collection_name = class_prefix.lower()
            else:
                dataset_id = self._dataset.id
                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
                index_struct_dict = {
                    "type": 'pgvecto_rs',
                    "vector_store": {"class_prefix": collection_name}
                }
                self._dataset.index_struct = json.dumps(index_struct_dict)
            dim = len(self._embeddings.embed_query("pgvecto_rs"))
            return PGVectoRS(
                collection_name=collection_name,
                config=PgvectoRSConfig(
                    host=config.get('PGVECTO_RS_HOST'),
                    port=config.get('PGVECTO_RS_PORT'),
                    user=config.get('PGVECTO_RS_USER'),
                    password=config.get('PGVECTO_RS_PASSWORD'),
                    database=config.get('PGVECTO_RS_DATABASE'),
                ),
                dim=dim
            )
        else:
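The new `pgvecto_rs` branch reads its connection settings from app config. A sketch of the keys it expects (all values are placeholders, and the `VECTOR_STORE` key naming is an assumption based on the surrounding factory code rather than something shown in this diff):

```python
# Placeholder config showing the keys the pgvecto_rs factory branch reads.
config = {
    'VECTOR_STORE': 'pgvecto_rs',      # assumed selector key, not shown above
    'PGVECTO_RS_HOST': 'localhost',
    'PGVECTO_RS_PORT': 5431,
    'PGVECTO_RS_USER': 'postgres',
    'PGVECTO_RS_PASSWORD': 'postgres',
    'PGVECTO_RS_DATABASE': 'dify',
}
```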
@@ -29,8 +29,7 @@ class WordExtractor(BaseExtractor):

            if r.status_code != 200:
                raise ValueError(
                    "Check the url of your file; returned status code %s"
                    % r.status_code
                    f"Check the url of your file; returned status code {r.status_code}"
                )

            self.web_path = self.file_path
@@ -38,7 +37,7 @@ class WordExtractor(BaseExtractor):
            self.temp_file.write(r.content)
            self.file_path = self.temp_file.name
        elif not os.path.isfile(self.file_path):
            raise ValueError("File path %s is not a valid file or url" % self.file_path)
            raise ValueError(f"File path {self.file_path} is not a valid file or url")

    def __del__(self) -> None:
        if hasattr(self, "temp_file"):
@@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="111" height="111" viewBox="0 0 111 111" fill="none">
  <text x="0" y="90" font-family="Verdana" font-size="85" fill="black">🔥</text>
</svg>
api/core/tools/provider/builtin/firecrawl/firecrawl.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.firecrawl.tools.crawl import CrawlTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FirecrawlProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            # Example validation using the Crawl tool
            CrawlTool().fork_tool_runtime(
                meta={"credentials": credentials}
            ).invoke(
                user_id='',
                tool_parameters={
                    "url": "https://example.com",
                    "includes": '',
                    "excludes": '',
                    "limit": 1,
                    "onlyMainContent": True,
                }
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
api/core/tools/provider/builtin/firecrawl/firecrawl.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
identity:
  author: Richards Tu
  name: firecrawl
  label:
    en_US: Firecrawl
    zh_CN: Firecrawl
  description:
    en_US: Firecrawl API integration for web crawling and scraping.
    zh_CN: Firecrawl API 集成,用于网页爬取和数据抓取。
  icon: icon.svg
credentials_for_provider:
  firecrawl_api_key:
    type: secret-input
    required: true
    label:
      en_US: Firecrawl API Key
      zh_CN: Firecrawl API 密钥
    placeholder:
      en_US: Please input your Firecrawl API key
      zh_CN: 请输入您的 Firecrawl API 密钥
    help:
      en_US: Get your Firecrawl API key from your Firecrawl account settings.
      zh_CN: 从您的 Firecrawl 账户设置中获取 Firecrawl API 密钥。
    url: https://www.firecrawl.dev/account
api/core/tools/provider/builtin/firecrawl/tools/crawl.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from typing import Any, Union

from firecrawl import FirecrawlApp

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class CrawlTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        # initialize the app object with the api key
        app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'])

        options = {
            'crawlerOptions': {
                'excludes': tool_parameters.get('excludes', '').split(',') if tool_parameters.get('excludes') else [],
                'includes': tool_parameters.get('includes', '').split(',') if tool_parameters.get('includes') else [],
                'limit': tool_parameters.get('limit', 5)
            },
            'pageOptions': {
                'onlyMainContent': tool_parameters.get('onlyMainContent', False)
            }
        }

        # crawl the url
        crawl_result = app.crawl_url(
            url=tool_parameters['url'],
            params=options,
            wait_until_done=True,
        )

        # reformat crawl result
        crawl_output = "**Crawl Result**\n\n"
        try:
            for result in crawl_result:
                crawl_output += f"**- Title:** {result.get('metadata', {}).get('title', '')}\n"
                crawl_output += f"**- Description:** {result.get('metadata', {}).get('description', '')}\n"
                crawl_output += f"**- URL:** {result.get('metadata', {}).get('ogUrl', '')}\n\n"
                crawl_output += f"**- Web Content:**\n{result.get('markdown', '')}\n\n"
                crawl_output += "---\n\n"
        except Exception as e:
            crawl_output += f"An error occurred: {str(e)}\n"

        return self.create_text_message(crawl_output)
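The formatting loop above assumes `crawl_url(..., wait_until_done=True)` returns a list of page dicts, each carrying `markdown` content and a `metadata` sub-dict. A sketch of the shape the loop consumes (the field names come from the code above; the values are invented):

```python
# Invented sample of the result shape the formatting loop iterates over.
crawl_result = [
    {
        'markdown': '# Example Domain\nThis domain is for use in examples...',
        'metadata': {
            'title': 'Example Domain',
            'description': 'Illustrative example page',
            'ogUrl': 'https://example.com',
        },
    },
]
```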
api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml (new file, 78 lines)
@@ -0,0 +1,78 @@
identity:
  name: crawl
  author: Richards Tu
  label:
    en_US: Crawl
    zh_Hans: 爬取
description:
  human:
    en_US: Extract data from a website by crawling through a URL.
    zh_Hans: 通过URL从网站中提取数据。
  llm: This tool initiates a web crawl to extract data from a specified URL. It allows configuring crawler options such as including or excluding URL patterns, generating alt text for images using LLMs (paid plan required), limiting the maximum number of pages to crawl, and returning only the main content of the page. The tool can return either a list of crawled documents or a list of URLs based on the provided options.
parameters:
  - name: url
    type: string
    required: true
    label:
      en_US: URL to crawl
      zh_Hans: 要爬取的URL
    human_description:
      en_US: The URL of the website to crawl and extract data from.
      zh_Hans: 要爬取并提取数据的网站URL。
    llm_description: The URL of the website that needs to be crawled. This is a required parameter.
    form: llm
  - name: includes
    type: string
    required: false
    label:
      en_US: URL patterns to include
      zh_Hans: 要包含的URL模式
    human_description:
      en_US: Specify URL patterns to include during the crawl. Only pages matching these patterns will be crawled; you can use ',' to separate multiple patterns.
      zh_Hans: 指定爬取过程中要包含的URL模式。只有与这些模式匹配的页面才会被爬取。
    form: form
    default: ''
  - name: excludes
    type: string
    required: false
    label:
      en_US: URL patterns to exclude
      zh_Hans: 要排除的URL模式
    human_description:
      en_US: Specify URL patterns to exclude during the crawl. Pages matching these patterns will be skipped; you can use ',' to separate multiple patterns.
      zh_Hans: 指定爬取过程中要排除的URL模式。匹配这些模式的页面将被跳过。
    form: form
    default: 'blog/*'
  - name: limit
    type: number
    required: false
    label:
      en_US: Maximum number of pages to crawl
      zh_Hans: 最大爬取页面数
    human_description:
      en_US: Specify the maximum number of pages to crawl. The crawler will stop after reaching this limit.
      zh_Hans: 指定要爬取的最大页面数。爬虫将在达到此限制后停止。
    form: form
    min: 1
    max: 20
    default: 5
  - name: onlyMainContent
    type: boolean
    required: false
    label:
      en_US: Only return the main content of the page
      zh_Hans: 仅返回页面的主要内容
    human_description:
      en_US: If enabled, the crawler will only return the main content of the page, excluding headers, navigation, footers, etc.
      zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
    form: form
    options:
      - value: true
        label:
          en_US: 'Yes'
          zh_Hans: 是
      - value: false
        label:
          en_US: 'No'
          zh_Hans: 否
    default: false
@@ -1,14 +1,14 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool
from core.tools.provider.builtin.judge0ce.tools.executeCode import ExecuteCodeTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class Judge0CEProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        try:
            SubmitCodeExecutionTaskTool().fork_tool_runtime(
            ExecuteCodeTool().fork_tool_runtime(
                meta={
                    "credentials": credentials,
                }
@@ -0,0 +1,59 @@
import json
from typing import Any, Union

import requests
from httpx import post

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class ExecuteCodeTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        api_key = self.runtime.credentials['X-RapidAPI-Key']

        url = "https://judge0-ce.p.rapidapi.com/submissions"

        querystring = {"base64_encoded": "false", "fields": "*"}

        headers = {
            "Content-Type": "application/json",
            "X-RapidAPI-Key": api_key,
            "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
        }

        payload = {
            "language_id": tool_parameters['language_id'],
            "source_code": tool_parameters['source_code'],
            "stdin": tool_parameters.get('stdin', ''),
            "expected_output": tool_parameters.get('expected_output', ''),
            "additional_files": tool_parameters.get('additional_files', ''),
        }

        response = post(url, data=json.dumps(payload), headers=headers, params=querystring)

        if response.status_code != 201:
            raise Exception(response.text)

        token = response.json()['token']

        url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}"
        headers = {
            "X-RapidAPI-Key": api_key
        }

        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            result = response.json()
            return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n"
                                                 f"stderr: {result.get('stderr', '')}\n"
                                                 f"compile_output: {result.get('compile_output', '')}\n"
                                                 f"message: {result.get('message', '')}\n"
                                                 f"status: {result['status']['description']}\n"
                                                 f"time: {result.get('time', '')} seconds\n"
                                                 f"memory: {result.get('memory', '')} bytes")
        else:
            return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
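One caveat with the flow above: Judge0 queues submissions, so the `GET` issued immediately after the `POST` can still report an in-flight status rather than a finished result. A hedged sketch of a polling variant (status IDs 1 and 2 are Judge0's "In Queue" and "Processing" codes; the retry count and sleep interval are arbitrary choices, not part of the committed tool):

```python
import time

import requests


def wait_for_result(token: str, api_key: str, retries: int = 10) -> dict:
    """Poll a Judge0 CE submission until it leaves the queued/processing states."""
    url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}"
    headers = {"X-RapidAPI-Key": api_key}
    for _ in range(retries):
        result = requests.get(url, headers=headers).json()
        if result['status']['id'] not in (1, 2):  # 1=In Queue, 2=Processing
            return result
        time.sleep(1)
    raise TimeoutError("submission did not finish in time")
```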
@@ -2,13 +2,13 @@ identity:
  name: submitCodeExecutionTask
  author: Richards Tu
  label:
    en_US: Submit Code Execution Task
    zh_Hans: 提交代码执行任务
    en_US: Submit Code Execution Task to Judge0 CE and get execution result.
    zh_Hans: 提交代码执行任务到 Judge0 CE 并获取执行结果。
description:
  human:
    en_US: A tool for submitting code execution task to Judge0 CE.
    zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。
  llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission.
    en_US: A tool for executing code and getting the result.
    zh_Hans: 一个用于执行代码并获取结果的工具。
  llm: This tool is used for executing code and getting the result.
parameters:
  - name: source_code
    type: string
@@ -1,37 +0,0 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GetExecutionResultTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        api_key = self.runtime.credentials['X-RapidAPI-Key']

        url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}"
        headers = {
            "X-RapidAPI-Key": api_key
        }

        response = requests.get(url, headers=headers)

        if response.status_code == 200:
            result = response.json()
            return self.create_text_message(text=f"Submission details:\n"
                                                 f"stdout: {result.get('stdout', '')}\n"
                                                 f"stderr: {result.get('stderr', '')}\n"
                                                 f"compile_output: {result.get('compile_output', '')}\n"
                                                 f"message: {result.get('message', '')}\n"
                                                 f"status: {result['status']['description']}\n"
                                                 f"time: {result.get('time', '')} seconds\n"
                                                 f"memory: {result.get('memory', '')} bytes")
        else:
            return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
@@ -1,23 +0,0 @@
identity:
  name: getExecutionResult
  author: Richards Tu
  label:
    en_US: Get Execution Result
    zh_Hans: 获取执行结果
description:
  human:
    en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
    zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。
  llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
parameters:
  - name: token
    type: string
    required: true
    label:
      en_US: Token
      zh_Hans: 令牌
    human_description:
      en_US: The submission's unique token.
      zh_Hans: 提交的唯一令牌。
    llm_description: The submission's unique token. MUST get from submitCodeExecution.
    form: llm
@@ -1,49 +0,0 @@
import json
from typing import Any, Union

from httpx import post

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class SubmitCodeExecutionTaskTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        api_key = self.runtime.credentials['X-RapidAPI-Key']

        source_code = tool_parameters['source_code']
        language_id = tool_parameters['language_id']
        stdin = tool_parameters.get('stdin', '')
        expected_output = tool_parameters.get('expected_output', '')
        additional_files = tool_parameters.get('additional_files', '')

        url = "https://judge0-ce.p.rapidapi.com/submissions"

        querystring = {"base64_encoded": "false", "fields": "*"}

        payload = {
            "language_id": language_id,
            "source_code": source_code,
            "stdin": stdin,
            "expected_output": expected_output,
            "additional_files": additional_files,
        }

        headers = {
            "content-type": "application/json",
            "Content-Type": "application/json",
            "X-RapidAPI-Key": api_key,
            "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
        }

        response = post(url, data=json.dumps(payload), headers=headers, params=querystring)

        if response.status_code != 201:
            raise Exception(response.text)

        token = response.json()['token']

        return self.create_text_message(text=token)
@@ -42,20 +42,19 @@ def get_url(url: str, user_agent: str = None) -> str:

    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)
    if response.status_code != 200:
        return "URL returned status code {}.".format(response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
        return ExtractProcessor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
@@ -141,10 +141,10 @@ class CodeNode(BaseNode):
        :return:
        """
        if not isinstance(value, str):
            raise ValueError(f"{variable} in output form must be a string")
            raise ValueError(f"Output variable `{variable}` must be a string")

        if len(value) > MAX_STRING_LENGTH:
            raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
            raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters')

        return value.replace('\x00', '')

@@ -156,15 +156,15 @@ class CodeNode(BaseNode):
        :return:
        """
        if not isinstance(value, int | float):
            raise ValueError(f"{variable} in output form must be a number")
            raise ValueError(f"Output variable `{variable}` must be a number")

        if value > MAX_NUMBER or value < MIN_NUMBER:
            raise ValueError(f'{variable} in input form is out of range.')
            raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.')

        if isinstance(value, float):
            # raise error if precision is too high
            if len(str(value).split('.')[1]) > MAX_PRECISION:
                raise ValueError(f'{variable} in output form has too high precision.')
                raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.')

        return value

@@ -271,7 +271,7 @@ class CodeNode(BaseNode):

                if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters.'
                        f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.'
                    )

                transformed_result[output_name] = [
@@ -290,7 +290,7 @@ class CodeNode(BaseNode):

                if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_STRING_ARRAY_LENGTH} characters.'
                        f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.'
                    )

                transformed_result[output_name] = [
@@ -309,7 +309,7 @@ class CodeNode(BaseNode):

                if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH:
                    raise ValueError(
                        f'{prefix}{dot}{output_name} in output form must be less than {MAX_OBJECT_ARRAY_LENGTH} characters.'
                        f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.'
                    )

                for i, value in enumerate(result[output_name]):
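The reworded messages above surface whenever a code node returns an out-of-spec value. A tiny illustration of the string check, decoupled from the class (the limit constants here are stand-ins; the real values come from the module's configuration):

```python
# Stand-in constant; the real value comes from the module's limits.
MAX_STRING_LENGTH = 80_000


def check_string(variable: str, value) -> str:
    if not isinstance(value, str):
        raise ValueError(f"Output variable `{variable}` must be a string")
    if len(value) > MAX_STRING_LENGTH:
        raise ValueError(f'The length of output variable `{variable}` must be '
                         f'less than {MAX_STRING_LENGTH} characters')
    return value.replace('\x00', '')  # strip NUL bytes, as the node does


check_string('result', 'ok')    # passes
# check_string('result', 123)  # -> Output variable `result` must be a string
```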
@@ -36,6 +36,49 @@ class EndNode(BaseNode):
            outputs=outputs
        )

    @classmethod
    def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
        """
        Extract generate nodes
        :param graph: graph
        :param config: node config
        :return:
        """
        node_data = cls._node_data_cls(**config.get("data", {}))
        node_data = cast(cls._node_data_cls, node_data)

        return cls.extract_generate_nodes_from_node_data(graph, node_data)

    @classmethod
    def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
        """
        Extract generate nodes from node data
        :param graph: graph
        :param node_data: node data object
        :return:
        """
        nodes = graph.get('nodes')
        node_mapping = {node.get('id'): node for node in nodes}

        variable_selectors = node_data.outputs

        generate_nodes = []
        for variable_selector in variable_selectors:
            if not variable_selector.value_selector:
                continue

            node_id = variable_selector.value_selector[0]
            if node_id != 'sys' and node_id in node_mapping:
                node = node_mapping[node_id]
                node_type = node.get('data', {}).get('type')
                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
                    generate_nodes.append(node_id)

        # remove duplicates
        generate_nodes = list(set(generate_nodes))

        return generate_nodes

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
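`extract_generate_nodes_from_node_data` walks the end node's output selectors and keeps only those that point at an LLM node's `text` output. A toy graph showing the behaviour (the node ids and layout are invented for illustration):

```python
# Invented two-node graph: an LLM node feeding the end node's single output.
graph = {
    'nodes': [
        {'id': 'llm-1', 'data': {'type': 'llm'}},
        {'id': 'end-1', 'data': {'type': 'end'}},
    ]
}
# An output whose value_selector is ['llm-1', 'text'] would make
# extract_generate_nodes_from_node_data return ['llm-1']; selectors that
# start with 'sys' or point at non-LLM nodes are skipped.
```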
@@ -35,9 +35,15 @@ class HttpRequestNodeData(BaseNodeData):
        type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
        data: Union[None, str]

    class Timeout(BaseModel):
        connect: int
        read: int
        write: int

    method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
    url: str
    authorization: Authorization
    headers: str
    params: str
    body: Optional[Body]
    body: Optional[Body]
    timeout: Optional[Timeout]
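With the nested `Timeout` model, a node's data can now carry per-phase limits. A sketch of the dict a workflow might ship for this node (the keys mirror the model above; the URL is a placeholder):

```python
# Placeholder node data exercising the new optional timeout block.
node_data = {
    'method': 'get',
    'url': 'https://api.example.com/health',  # placeholder endpoint
    'authorization': {'type': 'no-auth'},
    'headers': '',
    'params': '',
    'body': None,
    'timeout': {'connect': 5, 'read': 30, 'write': 10},  # seconds per phase
}
```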
@@ -13,7 +13,6 @@ from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser

HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)
MAX_BINARY_SIZE = 1024 * 1024 * 10  # 10MB
READABLE_MAX_BINARY_SIZE = '10MB'
MAX_TEXT_SIZE = 1024 * 1024 // 10  # 0.1MB
@@ -137,14 +136,16 @@ class HttpExecutor:
    files: Union[None, dict[str, Any]]
    boundary: str
    variable_selectors: list[VariableSelector]
    timeout: HttpRequestNodeData.Timeout

    def __init__(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
    def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, variable_pool: Optional[VariablePool] = None):
        """
        init
        """
        self.server_url = node_data.url
        self.method = node_data.method
        self.authorization = node_data.authorization
        self.timeout = timeout
        self.params = {}
        self.headers = {}
        self.body = None
@@ -307,7 +308,7 @@ class HttpExecutor:
            'url': self.server_url,
            'headers': headers,
            'params': self.params,
            'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
            'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
            'follow_redirects': True
        }
@@ -1,4 +1,5 @@
import logging
import os
from mimetypes import guess_extension
from os import path
from typing import cast
@@ -12,18 +13,49 @@ from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus

MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '300'))
MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))

HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT),
                                                           read=min(60, MAX_READ_TIMEOUT),
                                                           write=min(20, MAX_WRITE_TIMEOUT))


class HttpRequestNode(BaseNode):
    _node_data_cls = HttpRequestNodeData
    node_type = NodeType.HTTP_REQUEST

    @classmethod
    def get_default_config(cls) -> dict:
        return {
            "type": "http-request",
            "config": {
                "method": "get",
                "authorization": {
                    "type": "no-auth",
                },
                "body": {
                    "type": "none"
                },
                "timeout": {
                    **HTTP_REQUEST_DEFAULT_TIMEOUT.dict(),
                    "max_connect_timeout": MAX_CONNECT_TIMEOUT,
                    "max_read_timeout": MAX_READ_TIMEOUT,
                    "max_write_timeout": MAX_WRITE_TIMEOUT,
                }
            },
        }

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)

        # init http executor
        http_executor = None
        try:
            http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool)
            http_executor = HttpExecutor(node_data=node_data,
                                         timeout=self._get_request_timeout(node_data),
                                         variable_pool=variable_pool)

            # invoke http executor
            response = http_executor.invoke()
@@ -38,7 +70,7 @@ class HttpRequestNode(BaseNode):
                error=str(e),
                process_data=process_data
            )

        files = self.extract_files(http_executor.server_url, response)

        return NodeRunResult(
@@ -54,6 +86,16 @@ class HttpRequestNode(BaseNode):
            }
        )

    def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeData.Timeout:
        timeout = node_data.timeout
        if timeout is None:
            return HTTP_REQUEST_DEFAULT_TIMEOUT

        timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT)
        timeout.read = min(timeout.read, MAX_READ_TIMEOUT)
        timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT)
        return timeout

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]:
        """
@@ -62,7 +104,7 @@ class HttpRequestNode(BaseNode):
        :return:
        """
        try:
            http_executor = HttpExecutor(node_data=node_data)
            http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)

            variable_selectors = http_executor.variable_selectors

@@ -84,7 +126,7 @@ class HttpRequestNode(BaseNode):
        # if not image, return directly
        if 'image' not in mimetype:
            return files

        if mimetype:
            # extract filename from url
            filename = path.basename(url)
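`_get_request_timeout` clamps whatever the node supplies to the environment-driven maxima rather than rejecting it. A quick illustration with the default caps above (assuming no `HTTP_REQUEST_MAX_*` variables are set, so the limits are 300/600/600 seconds; this snippet reuses the `Timeout` model from the entities diff):

```python
# With default caps (connect<=300, read<=600, write<=600), an oversized
# node-level timeout is silently clamped.
requested = HttpRequestNodeData.Timeout(connect=1000, read=30, write=10)
requested.connect = min(requested.connect, 300)  # -> 300
requested.read = min(requested.read, 600)        # -> 30 (unchanged)
requested.write = min(requested.write, 600)      # -> 10 (unchanged)
```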
@@ -79,7 +79,6 @@ class QuestionClassifierNode(LLMNode):
                prompt_messages=prompt_messages
            ),
            'usage': jsonable_encoder(usage),
            'topics': categories[0] if categories else ''
        }
        outputs = {
            'class_name': categories[0] if categories else ''
@@ -1,80 +1,42 @@
import os
import shutil
from collections.abc import Generator
from contextlib import closing
from datetime import datetime, timedelta, timezone
from typing import Union

import boto3
import oss2 as aliyun_s3
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from botocore.client import Config
from botocore.exceptions import ClientError
from flask import Flask

from extensions.storage.aliyun_storage import AliyunStorage
from extensions.storage.azure_storage import AzureStorage
from extensions.storage.google_storage import GoogleStorage
from extensions.storage.local_storage import LocalStorage
from extensions.storage.s3_storage import S3Storage


class Storage:
    def __init__(self):
        self.storage_type = None
        self.bucket_name = None
        self.client = None
        self.folder = None
        self.storage_runner = None

    def init_app(self, app: Flask):
        self.storage_type = app.config.get('STORAGE_TYPE')
        if self.storage_type == 's3':
            self.bucket_name = app.config.get('S3_BUCKET_NAME')
            self.client = boto3.client(
                's3',
                aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
                aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
                endpoint_url=app.config.get('S3_ENDPOINT'),
                region_name=app.config.get('S3_REGION'),
                config=Config(s3={'addressing_style': app.config.get('S3_ADDRESS_STYLE')})
        storage_type = app.config.get('STORAGE_TYPE')
        if storage_type == 's3':
            self.storage_runner = S3Storage(
                app=app
            )
        elif self.storage_type == 'azure-blob':
            self.bucket_name = app.config.get('AZURE_BLOB_CONTAINER_NAME')
            sas_token = generate_account_sas(
                account_name=app.config.get('AZURE_BLOB_ACCOUNT_NAME'),
                account_key=app.config.get('AZURE_BLOB_ACCOUNT_KEY'),
                resource_types=ResourceTypes(service=True, container=True, object=True),
                permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
                expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
        elif storage_type == 'azure-blob':
            self.storage_runner = AzureStorage(
                app=app
            )
            self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'),
                                            credential=sas_token)
        elif self.storage_type == 'aliyun-oss':
            self.bucket_name = app.config.get('ALIYUN_OSS_BUCKET_NAME')
            self.client = aliyun_s3.Bucket(
                aliyun_s3.Auth(app.config.get('ALIYUN_OSS_ACCESS_KEY'), app.config.get('ALIYUN_OSS_SECRET_KEY')),
                app.config.get('ALIYUN_OSS_ENDPOINT'),
                self.bucket_name,
                connect_timeout=30
        elif storage_type == 'aliyun-oss':
            self.storage_runner = AliyunStorage(
                app=app
            )
        elif storage_type == 'google-storage':
            self.storage_runner = GoogleStorage(
                app=app
            )
        else:
            self.folder = app.config.get('STORAGE_LOCAL_PATH')
            if not os.path.isabs(self.folder):
                self.folder = os.path.join(app.root_path, self.folder)
            self.storage_runner = LocalStorage(app=app)

    def save(self, filename, data):
        if self.storage_type == 's3':
            self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
        elif self.storage_type == 'azure-blob':
            blob_container = self.client.get_container_client(container=self.bucket_name)
            blob_container.upload_blob(filename, data)
        elif self.storage_type == 'aliyun-oss':
            self.client.put_object(filename, data)
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename

            folder = os.path.dirname(filename)
            os.makedirs(folder, exist_ok=True)

            with open(os.path.join(os.getcwd(), filename), "wb") as f:
                f.write(data)
        self.storage_runner.save(filename, data)

    def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
        if stream:
@@ -83,131 +45,19 @@ class Storage:
            return self.load_once(filename)

    def load_once(self, filename: str) -> bytes:
        if self.storage_type == 's3':
            try:
                with closing(self.client) as client:
                    data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
            except ClientError as ex:
                if ex.response['Error']['Code'] == 'NoSuchKey':
                    raise FileNotFoundError("File not found")
                else:
                    raise
        elif self.storage_type == 'azure-blob':
            blob = self.client.get_container_client(container=self.bucket_name)
            blob = blob.get_blob_client(blob=filename)
            data = blob.download_blob().readall()
        elif self.storage_type == 'aliyun-oss':
            with closing(self.client.get_object(filename)) as obj:
                data = obj.read()
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename

            if not os.path.exists(filename):
                raise FileNotFoundError("File not found")

            with open(filename, "rb") as f:
                data = f.read()

        return data
        return self.storage_runner.load_once(filename)

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            if self.storage_type == 's3':
                try:
                    with closing(self.client) as client:
                        response = client.get_object(Bucket=self.bucket_name, Key=filename)
                        for chunk in response['Body'].iter_chunks():
                            yield chunk
                except ClientError as ex:
                    if ex.response['Error']['Code'] == 'NoSuchKey':
                        raise FileNotFoundError("File not found")
                    else:
                        raise
            elif self.storage_type == 'azure-blob':
                blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
                with closing(blob.download_blob()) as blob_stream:
                    while chunk := blob_stream.readall(4096):
                        yield chunk
            elif self.storage_type == 'aliyun-oss':
                with closing(self.client.get_object(filename)) as obj:
                    while chunk := obj.read(4096):
                        yield chunk
            else:
                if not self.folder or self.folder.endswith('/'):
                    filename = self.folder + filename
                else:
                    filename = self.folder + '/' + filename

                if not os.path.exists(filename):
                    raise FileNotFoundError("File not found")

                with open(filename, "rb") as f:
                    while chunk := f.read(4096):  # Read in chunks of 4KB
                        yield chunk

        return generate()
        return self.storage_runner.load_stream(filename)

    def download(self, filename, target_filepath):
        if self.storage_type == 's3':
            with closing(self.client) as client:
                client.download_file(self.bucket_name, filename, target_filepath)
        elif self.storage_type == 'azure-blob':
            blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
            with open(target_filepath, "wb") as my_blob:
                blob_data = blob.download_blob()
                blob_data.readinto(my_blob)
        elif self.storage_type == 'aliyun-oss':
            self.client.get_object_to_file(filename, target_filepath)
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename

            if not os.path.exists(filename):
                raise FileNotFoundError("File not found")

            shutil.copyfile(filename, target_filepath)
        self.storage_runner.download(filename, target_filepath)

    def exists(self, filename):
        if self.storage_type == 's3':
            with closing(self.client) as client:
                try:
                    client.head_object(Bucket=self.bucket_name, Key=filename)
                    return True
                except:
                    return False
        elif self.storage_type == 'azure-blob':
            blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
            return blob.exists()
        elif self.storage_type == 'aliyun-oss':
            return self.client.object_exists(filename)
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename

            return os.path.exists(filename)
        return self.storage_runner.exists(filename)

    def delete(self, filename):
        if self.storage_type == 's3':
            self.client.delete_object(Bucket=self.bucket_name, Key=filename)
        elif self.storage_type == 'azure-blob':
            blob_container = self.client.get_container_client(container=self.bucket_name)
            blob_container.delete_blob(filename)
        elif self.storage_type == 'aliyun-oss':
            self.client.delete_object(filename)
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename
            if os.path.exists(filename):
                os.remove(filename)
        return self.storage_runner.delete(filename)


storage = Storage()
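After this refactor, `Storage` only dispatches to a backend object; the per-provider logic lives under `extensions/storage/`. The `BaseStorage` interface itself is not part of this diff, but from the calls made above it plausibly looks like the sketch below (a hypothetical reconstruction, not the committed file):

```python
# Hypothetical reconstruction of the BaseStorage interface implied by the
# refactor above; the real file would live at api/extensions/storage/base_storage.py.
from abc import ABC, abstractmethod
from collections.abc import Generator

from flask import Flask


class BaseStorage(ABC):
    def __init__(self, app: Flask):
        self.app = app

    @abstractmethod
    def save(self, filename, data): ...

    @abstractmethod
    def load_once(self, filename: str) -> bytes: ...

    @abstractmethod
    def load_stream(self, filename: str) -> Generator: ...

    @abstractmethod
    def download(self, filename, target_filepath): ...

    @abstractmethod
    def exists(self, filename): ...

    @abstractmethod
    def delete(self, filename): ...
```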
48
api/extensions/storage/aliyun_storage.py
Normal file
@ -0,0 +1,48 @@
from collections.abc import Generator
from contextlib import closing

import oss2 as aliyun_s3
from flask import Flask

from extensions.storage.base_storage import BaseStorage


class AliyunStorage(BaseStorage):
    """Implementation for aliyun storage.
    """

    def __init__(self, app: Flask):
        super().__init__(app)
        app_config = self.app.config
        self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME')
        self.client = aliyun_s3.Bucket(
            aliyun_s3.Auth(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')),
            app_config.get('ALIYUN_OSS_ENDPOINT'),
            self.bucket_name,
            connect_timeout=30
        )

    def save(self, filename, data):
        self.client.put_object(filename, data)

    def load_once(self, filename: str) -> bytes:
        with closing(self.client.get_object(filename)) as obj:
            data = obj.read()
        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            with closing(self.client.get_object(filename)) as obj:
                while chunk := obj.read(4096):
                    yield chunk

        return generate()

    def download(self, filename, target_filepath):
        self.client.get_object_to_file(filename, target_filepath)

    def exists(self, filename):
        return self.client.object_exists(filename)

    def delete(self, filename):
        self.client.delete_object(filename)
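The constructor reads four ALIYUN_OSS_* keys from the Flask config. A hedged configuration sketch with placeholder values (the STORAGE_TYPE selector key is assumed from the facade code, not shown in this file):

app.config.update(
    STORAGE_TYPE='aliyun-oss',   # assumed selector used by the Storage facade
    ALIYUN_OSS_BUCKET_NAME='my-bucket',
    ALIYUN_OSS_ACCESS_KEY='LTAI...',
    ALIYUN_OSS_SECRET_KEY='...',
    ALIYUN_OSS_ENDPOINT='https://oss-cn-hangzhou.aliyuncs.com',
)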
58
api/extensions/storage/azure_storage.py
Normal file
@ -0,0 +1,58 @@
from collections.abc import Generator
from contextlib import closing
from datetime import datetime, timedelta, timezone

from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from flask import Flask

from extensions.storage.base_storage import BaseStorage


class AzureStorage(BaseStorage):
    """Implementation for azure storage.
    """
    def __init__(self, app: Flask):
        super().__init__(app)
        app_config = self.app.config
        self.bucket_name = app_config.get('AZURE_STORAGE_CONTAINER_NAME')
        # Account-level SAS token, valid for one hour from construction.
        sas_token = generate_account_sas(
            account_name=app_config.get('AZURE_BLOB_ACCOUNT_NAME'),
            account_key=app_config.get('AZURE_BLOB_ACCOUNT_KEY'),
            resource_types=ResourceTypes(service=True, container=True, object=True),
            permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
            expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1)
        )
        self.client = BlobServiceClient(account_url=app_config.get('AZURE_BLOB_ACCOUNT_URL'),
                                        credential=sas_token)

    def save(self, filename, data):
        blob_container = self.client.get_container_client(container=self.bucket_name)
        blob_container.upload_blob(filename, data)

    def load_once(self, filename: str) -> bytes:
        blob = self.client.get_container_client(container=self.bucket_name)
        blob = blob.get_blob_client(blob=filename)
        data = blob.download_blob().readall()
        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
            with closing(blob.download_blob()) as blob_stream:
                # readall() takes no size argument; read(size) is the chunked call.
                while chunk := blob_stream.read(4096):
                    yield chunk

        return generate()

    def download(self, filename, target_filepath):
        blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
        with open(target_filepath, "wb") as my_blob:
            blob_data = blob.download_blob()
            blob_data.readinto(my_blob)

    def exists(self, filename):
        blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
        return blob.exists()

    def delete(self, filename):
        blob_container = self.client.get_container_client(container=self.bucket_name)
        blob_container.delete_blob(filename)
38
api/extensions/storage/base_storage.py
Normal file
@ -0,0 +1,38 @@
"""Abstract interface for file storage implementations."""
from abc import ABC, abstractmethod
from collections.abc import Generator

from flask import Flask


class BaseStorage(ABC):
    """Interface for file storage.
    """
    app = None

    def __init__(self, app: Flask):
        self.app = app

    @abstractmethod
    def save(self, filename, data):
        raise NotImplementedError

    @abstractmethod
    def load_once(self, filename: str) -> bytes:
        raise NotImplementedError

    @abstractmethod
    def load_stream(self, filename: str) -> Generator:
        raise NotImplementedError

    @abstractmethod
    def download(self, filename, target_filepath):
        raise NotImplementedError

    @abstractmethod
    def exists(self, filename):
        raise NotImplementedError

    @abstractmethod
    def delete(self, filename):
        raise NotImplementedError
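Every backend in this commit subclasses BaseStorage and fills in these six methods. As a hedged illustration (a toy backend, not part of the diff), an in-memory store satisfying the same contract could look like this:

from collections.abc import Generator

from flask import Flask

from extensions.storage.base_storage import BaseStorage


class MemoryStorage(BaseStorage):
    """Hypothetical in-memory backend, for illustration only."""

    def __init__(self, app: Flask):
        super().__init__(app)
        self._files: dict[str, bytes] = {}

    def save(self, filename, data):
        self._files[filename] = data

    def load_once(self, filename: str) -> bytes:
        if filename not in self._files:
            raise FileNotFoundError("File not found")
        return self._files[filename]

    def load_stream(self, filename: str) -> Generator:
        def generate() -> Generator:
            data = self.load_once(filename)
            for i in range(0, len(data), 4096):  # mirror the 4 KB chunking used elsewhere
                yield data[i:i + 4096]
        return generate()

    def download(self, filename, target_filepath):
        with open(target_filepath, "wb") as f:
            f.write(self.load_once(filename))

    def exists(self, filename):
        return filename in self._files

    def delete(self, filename):
        self._files.pop(filename, None)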
56
api/extensions/storage/google_storage.py
Normal file
@ -0,0 +1,56 @@
import base64
import json
from collections.abc import Generator
from contextlib import closing

from flask import Flask
from google.cloud import storage as GoogleCloudStorage
from google.oauth2 import service_account

from extensions.storage.base_storage import BaseStorage


class GoogleStorage(BaseStorage):
    """Implementation for google storage.
    """
    def __init__(self, app: Flask):
        super().__init__(app)
        app_config = self.app.config
        self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME')
        service_account_json = base64.b64decode(app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')).decode(
            'utf-8')
        # The decoded value is the JSON *content*, not a file path, so build
        # credentials from it instead of calling from_service_account_json(),
        # which expects a filename.
        service_account_info = json.loads(service_account_json)
        credentials = service_account.Credentials.from_service_account_info(service_account_info)
        self.client = GoogleCloudStorage.Client(project=service_account_info['project_id'], credentials=credentials)

    def save(self, filename, data):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        blob.upload_from_file(data)

    def load_once(self, filename: str) -> bytes:
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        data = blob.download_as_bytes()
        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            bucket = self.client.get_bucket(self.bucket_name)
            blob = bucket.get_blob(filename)
            with closing(blob.open(mode='rb')) as blob_stream:
                while chunk := blob_stream.read(4096):
                    yield chunk
        return generate()

    def download(self, filename, target_filepath):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.get_blob(filename)
        # download_blob()/readinto() are Azure SDK calls; GCS blobs write
        # straight to a local path.
        blob.download_to_filename(target_filepath)

    def exists(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        blob = bucket.blob(filename)
        return blob.exists()

    def delete(self, filename):
        bucket = self.client.get_bucket(self.bucket_name)
        bucket.delete_blob(filename)
88
api/extensions/storage/local_storage.py
Normal file
@ -0,0 +1,88 @@
import os
import shutil
from collections.abc import Generator

from flask import Flask

from extensions.storage.base_storage import BaseStorage


class LocalStorage(BaseStorage):
    """Implementation for local storage.
    """

    def __init__(self, app: Flask):
        super().__init__(app)
        folder = self.app.config.get('STORAGE_LOCAL_PATH')
        if not os.path.isabs(folder):
            folder = os.path.join(app.root_path, folder)
        self.folder = folder

    def save(self, filename, data):
        if not self.folder or self.folder.endswith('/'):
            filename = self.folder + filename
        else:
            filename = self.folder + '/' + filename

        folder = os.path.dirname(filename)
        os.makedirs(folder, exist_ok=True)

        with open(os.path.join(os.getcwd(), filename), "wb") as f:
            f.write(data)

    def load_once(self, filename: str) -> bytes:
        if not self.folder or self.folder.endswith('/'):
            filename = self.folder + filename
        else:
            filename = self.folder + '/' + filename

        if not os.path.exists(filename):
            raise FileNotFoundError("File not found")

        with open(filename, "rb") as f:
            data = f.read()

        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename

            if not os.path.exists(filename):
                raise FileNotFoundError("File not found")

            with open(filename, "rb") as f:
                while chunk := f.read(4096):  # Read in chunks of 4KB
                    yield chunk

        return generate()

    def download(self, filename, target_filepath):
        if not self.folder or self.folder.endswith('/'):
            filename = self.folder + filename
        else:
            filename = self.folder + '/' + filename

        if not os.path.exists(filename):
            raise FileNotFoundError("File not found")

        shutil.copyfile(filename, target_filepath)

    def exists(self, filename):
        if not self.folder or self.folder.endswith('/'):
            filename = self.folder + filename
        else:
            filename = self.folder + '/' + filename

        return os.path.exists(filename)

    def delete(self, filename):
        if not self.folder or self.folder.endswith('/'):
            filename = self.folder + filename
        else:
            filename = self.folder + '/' + filename
        if os.path.exists(filename):
            os.remove(filename)
68
api/extensions/storage/s3_storage.py
Normal file
@ -0,0 +1,68 @@
from collections.abc import Generator
from contextlib import closing

import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from flask import Flask

from extensions.storage.base_storage import BaseStorage


class S3Storage(BaseStorage):
    """Implementation for s3 storage.
    """
    def __init__(self, app: Flask):
        super().__init__(app)
        app_config = self.app.config
        self.bucket_name = app_config.get('S3_BUCKET_NAME')
        self.client = boto3.client(
            's3',
            aws_secret_access_key=app_config.get('S3_SECRET_KEY'),
            aws_access_key_id=app_config.get('S3_ACCESS_KEY'),
            endpoint_url=app_config.get('S3_ENDPOINT'),
            region_name=app_config.get('S3_REGION'),
            config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')})
        )

    def save(self, filename, data):
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        try:
            with closing(self.client) as client:
                data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
        except ClientError as ex:
            if ex.response['Error']['Code'] == 'NoSuchKey':
                raise FileNotFoundError("File not found")
            else:
                raise
        return data

    def load_stream(self, filename: str) -> Generator:
        def generate(filename: str = filename) -> Generator:
            try:
                with closing(self.client) as client:
                    response = client.get_object(Bucket=self.bucket_name, Key=filename)
                    yield from response['Body'].iter_chunks()
            except ClientError as ex:
                if ex.response['Error']['Code'] == 'NoSuchKey':
                    raise FileNotFoundError("File not found")
                else:
                    raise
        return generate()

    def download(self, filename, target_filepath):
        with closing(self.client) as client:
            client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        with closing(self.client) as client:
            try:
                client.head_object(Bucket=self.bucket_name, Key=filename)
                return True
            except ClientError:  # narrowed from a bare except: head_object signals a missing key via ClientError
                return False

    def delete(self, filename):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
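S3Storage pulls six S3_* keys from the Flask config; addressing_style accepts botocore's 'auto', 'virtual', or 'path'. A hedged configuration sketch with placeholder values (the STORAGE_TYPE selector key is assumed from the facade code, not shown in this file):

app.config.update(
    STORAGE_TYPE='s3',            # assumed selector used by the Storage facade
    S3_BUCKET_NAME='my-bucket',
    S3_ACCESS_KEY='AKIA...',
    S3_SECRET_KEY='...',
    S3_ENDPOINT='https://s3.us-east-1.amazonaws.com',
    S3_REGION='us-east-1',
    S3_ADDRESS_STYLE='auto',      # 'auto', 'virtual', or 'path'
)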
@ -153,23 +153,12 @@ class NotionOAuth(OAuthDataSource):
        # get page detail
        for page_result in page_results:
            page_id = page_result['id']
            if 'Name' in page_result['properties']:
                if len(page_result['properties']['Name']['title']) > 0:
                    page_name = page_result['properties']['Name']['title'][0]['plain_text']
                else:
                    page_name = 'Untitled'
            elif 'title' in page_result['properties']:
                if len(page_result['properties']['title']['title']) > 0:
                    page_name = page_result['properties']['title']['title'][0]['plain_text']
                else:
                    page_name = 'Untitled'
            elif 'Title' in page_result['properties']:
                if len(page_result['properties']['Title']['title']) > 0:
                    page_name = page_result['properties']['Title']['title'][0]['plain_text']
                else:
                    page_name = 'Untitled'
            else:
                page_name = 'Untitled'
            page_name = 'Untitled'
            for key in ['Name', 'title', 'Title', 'Page']:
                if key in page_result['properties']:
                    if len(page_result['properties'][key].get('title', [])) > 0:
                        page_name = page_result['properties'][key]['title'][0]['plain_text']
                        break
            page_icon = page_result['icon']
            if page_icon:
                icon_type = page_icon['type']
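The refactor collapses the three near-identical property branches into a single key loop with an 'Untitled' default. A quick self-contained sanity check on a hypothetical Notion properties payload:

# Hypothetical Notion search result, for illustration only.
page_result = {'properties': {'Title': {'title': [{'plain_text': 'Roadmap'}]}}}

page_name = 'Untitled'
for key in ['Name', 'title', 'Title', 'Page']:
    if key in page_result['properties']:
        if len(page_result['properties'][key].get('title', [])) > 0:
            page_name = page_result['properties'][key]['title'][0]['plain_text']
            break

assert page_name == 'Roadmap'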
@ -6,6 +6,7 @@ Create Date: ${create_date}

"""
from alembic import op
import models as models
import sqlalchemy as sa
${imports if imports else ""}

@ -1,5 +1,8 @@
from enum import Enum

from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID


class CreatedByRole(Enum):
    """
@ -42,3 +45,27 @@ class CreatedFrom(Enum):
            if role.value == value:
                return role
        raise ValueError(f'invalid createdFrom value {value}')


class StringUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        elif dialect.name == 'postgresql':
            return str(value)
        else:
            return value.hex

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(UUID())
        else:
            return dialect.type_descriptor(CHAR(36))

    def process_result_value(self, value, dialect):
        if value is None:
            return value
        return str(value)
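StringUUID lets the same model declarations run on PostgreSQL (native UUID) and on other dialects (CHAR(36), bound as hex). A minimal round-trip sketch, assuming a scratch SQLite engine; the Example table is hypothetical, not part of this diff:

import uuid

from sqlalchemy import Column, create_engine
from sqlalchemy.orm import Session, declarative_base

from models import StringUUID

Base = declarative_base()


class Example(Base):  # hypothetical table, for illustration only
    __tablename__ = 'example'
    id = Column(StringUUID, primary_key=True, default=uuid.uuid4)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Example())
    session.commit()
    row = session.query(Example).first()
    print(type(row.id), row.id)  # <class 'str'>: bound as value.hex on non-PostgreSQL dialects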
@ -2,9 +2,9 @@ import enum
import json

from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db
from models import StringUUID


class AccountStatus(str, enum.Enum):
@ -22,7 +22,7 @@ class Account(UserMixin, db.Model):
        db.Index('account_email_idx', 'email')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    name = db.Column(db.String(255), nullable=False)
    email = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=True)
@ -128,7 +128,7 @@ class Tenant(db.Model):
        db.PrimaryKeyConstraint('id', name='tenant_pkey'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    name = db.Column(db.String(255), nullable=False)
    encrypt_public_key = db.Column(db.Text)
    plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
@ -168,12 +168,12 @@ class TenantAccountJoin(db.Model):
        db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    account_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    role = db.Column(db.String(16), nullable=False, server_default='normal')
    invited_by = db.Column(UUID, nullable=True)
    invited_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -186,8 +186,8 @@ class AccountIntegrate(db.Model):
        db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    account_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    account_id = db.Column(StringUUID, nullable=False)
    provider = db.Column(db.String(16), nullable=False)
    open_id = db.Column(db.String(255), nullable=False)
    encrypted_token = db.Column(db.String(255), nullable=False)
@ -208,7 +208,7 @@ class InvitationCode(db.Model):
    code = db.Column(db.String(32), nullable=False)
    status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
    used_at = db.Column(db.DateTime)
    used_by_tenant_id = db.Column(UUID)
    used_by_account_id = db.Column(UUID)
    used_by_tenant_id = db.Column(StringUUID)
    used_by_account_id = db.Column(StringUUID)
    deprecated_at = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -1,8 +1,7 @@
import enum

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db
from models import StringUUID


class APIBasedExtensionPoint(enum.Enum):
@ -19,8 +18,8 @@ class APIBasedExtension(db.Model):
        db.Index('api_based_extension_tenant_idx', 'tenant_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    api_endpoint = db.Column(db.String(255), nullable=False)
    api_key = db.Column(db.Text, nullable=False)

@ -4,10 +4,11 @@ import pickle
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import JSONB

from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile

@ -22,8 +23,8 @@ class Dataset(db.Model):

    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False,
@ -33,15 +34,15 @@ class Dataset(db.Model):
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(UUID, nullable=True)
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(UUID, nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model):
        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(UUID, nullable=False,
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(UUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False,
                     server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -197,19 +198,19 @@ class Document(db.Model):
    )

    # initial fields
    id = db.Column(UUID, nullable=False,
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    dataset_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(UUID, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_api_request_id = db.Column(UUID, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -234,7 +235,7 @@ class Document(db.Model):

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
    paused_by = db.Column(UUID, nullable=True)
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
@ -247,11 +248,11 @@ class Document(db.Model):
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(UUID, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('false'))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(UUID, nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -356,11 +357,11 @@ class DocumentSegment(db.Model):
    )

    # initial fields
    id = db.Column(UUID, nullable=False,
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    dataset_id = db.Column(UUID, nullable=False)
    document_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
@ -377,13 +378,13 @@ class DocumentSegment(db.Model):
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(UUID, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False,
                       server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(UUID, nullable=True)
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    indexing_at = db.Column(db.DateTime, nullable=True)
@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model):
        db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
    )

    id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    dataset_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
@ -438,13 +439,13 @@ class DatasetQuery(db.Model):
        db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(UUID, nullable=True)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model):
        db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(UUID, nullable=False, unique=True)
    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False,
                                 server_default=db.text("'database'::character varying"))
@ -501,7 +502,7 @@ class Embedding(db.Model):
        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    model_name = db.Column(db.String(40), nullable=False,
                           server_default=db.text("'text-embedding-ada-002'::character varying"))
    hash = db.Column(db.String(64), nullable=False)
@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model):

    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)

@ -7,13 +7,13 @@ from typing import Optional
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy import Float, text
from sqlalchemy.dialects.postgresql import UUID

from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string

from . import StringUUID
from .account import Account, Tenant


@ -56,15 +56,15 @@ class App(db.Model):
        db.Index('app_tenant_id_idx', 'tenant_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
    mode = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.String(255))
    icon_background = db.Column(db.String(255))
    app_model_config_id = db.Column(UUID, nullable=True)
    workflow_id = db.Column(UUID, nullable=True)
    app_model_config_id = db.Column(StringUUID, nullable=True)
    workflow_id = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
    enable_site = db.Column(db.Boolean, nullable=False)
    enable_api = db.Column(db.Boolean, nullable=False)
@ -207,8 +207,8 @@ class AppModelConfig(db.Model):
        db.Index('app_app_id_idx', 'app_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    provider = db.Column(db.String(255), nullable=True)
    model_id = db.Column(db.String(255), nullable=True)
    configs = db.Column(db.JSON, nullable=True)
@ -430,8 +430,8 @@ class RecommendedApp(db.Model):
        db.Index('recommended_app_is_listed_idx', 'is_listed', 'language')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    description = db.Column(db.JSON, nullable=False)
    copyright = db.Column(db.String(255), nullable=False)
    privacy_policy = db.Column(db.String(255), nullable=False)
@ -458,10 +458,10 @@ class InstalledApp(db.Model):
        db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=False)
    app_owner_tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=False)
    app_owner_tenant_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False, default=0)
    is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    last_used_at = db.Column(db.DateTime, nullable=True)
@ -486,9 +486,9 @@ class Conversation(db.Model):
        db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    app_model_config_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    app_model_config_id = db.Column(StringUUID, nullable=True)
    model_provider = db.Column(db.String(255), nullable=True)
    override_model_configs = db.Column(db.Text)
    model_id = db.Column(db.String(255), nullable=True)
@ -502,10 +502,10 @@ class Conversation(db.Model):
    status = db.Column(db.String(255), nullable=False)
    invoke_from = db.Column(db.String(255), nullable=True)
    from_source = db.Column(db.String(255), nullable=False)
    from_end_user_id = db.Column(UUID)
    from_account_id = db.Column(UUID)
    from_end_user_id = db.Column(StringUUID)
    from_account_id = db.Column(StringUUID)
    read_at = db.Column(db.DateTime)
    read_account_id = db.Column(UUID)
    read_account_id = db.Column(StringUUID)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -626,12 +626,12 @@ class Message(db.Model):
        db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    model_provider = db.Column(db.String(255), nullable=True)
    model_id = db.Column(db.String(255), nullable=True)
    override_model_configs = db.Column(db.Text)
    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
    inputs = db.Column(db.JSON)
    query = db.Column(db.Text, nullable=False)
    message = db.Column(db.JSON, nullable=False)
@ -650,12 +650,12 @@ class Message(db.Model):
    message_metadata = db.Column(db.Text)
    invoke_from = db.Column(db.String(255), nullable=True)
    from_source = db.Column(db.String(255), nullable=False)
    from_end_user_id = db.Column(UUID)
    from_account_id = db.Column(UUID)
    from_end_user_id = db.Column(StringUUID)
    from_account_id = db.Column(StringUUID)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    workflow_run_id = db.Column(UUID)
    workflow_run_id = db.Column(StringUUID)

    @property
    def re_sign_file_url_answer(self) -> str:
@ -846,15 +846,15 @@ class MessageFeedback(db.Model):
        db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    conversation_id = db.Column(UUID, nullable=False)
    message_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    conversation_id = db.Column(StringUUID, nullable=False)
    message_id = db.Column(StringUUID, nullable=False)
    rating = db.Column(db.String(255), nullable=False)
    content = db.Column(db.Text)
    from_source = db.Column(db.String(255), nullable=False)
    from_end_user_id = db.Column(UUID)
    from_account_id = db.Column(UUID)
    from_end_user_id = db.Column(StringUUID)
    from_account_id = db.Column(StringUUID)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -872,15 +872,15 @@ class MessageFile(db.Model):
        db.Index('message_file_created_by_idx', 'created_by')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    transfer_method = db.Column(db.String(255), nullable=False)
    url = db.Column(db.Text, nullable=True)
    belongs_to = db.Column(db.String(255), nullable=True)
    upload_file_id = db.Column(UUID, nullable=True)
    upload_file_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


@ -893,14 +893,14 @@ class MessageAnnotation(db.Model):
        db.Index('message_annotation_message_idx', 'message_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
    message_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
    message_id = db.Column(StringUUID, nullable=True)
    question = db.Column(db.Text, nullable=True)
    content = db.Column(db.Text, nullable=False)
    hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
    account_id = db.Column(UUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -925,15 +925,15 @@ class AppAnnotationHitHistory(db.Model):
        db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    annotation_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    annotation_id = db.Column(StringUUID, nullable=False)
    source = db.Column(db.Text, nullable=False)
    question = db.Column(db.Text, nullable=False)
    account_id = db.Column(UUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    score = db.Column(Float, nullable=False, server_default=db.text('0'))
    message_id = db.Column(UUID, nullable=False)
    message_id = db.Column(StringUUID, nullable=False)
    annotation_question = db.Column(db.Text, nullable=False)
    annotation_content = db.Column(db.Text, nullable=False)

@ -957,13 +957,13 @@ class AppAnnotationSetting(db.Model):
        db.Index('app_annotation_settings_app_idx', 'app_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
    collection_binding_id = db.Column(UUID, nullable=False)
    created_user_id = db.Column(UUID, nullable=False)
    collection_binding_id = db.Column(StringUUID, nullable=False)
    created_user_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_user_id = db.Column(UUID, nullable=False)
    updated_user_id = db.Column(StringUUID, nullable=False)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
@ -995,9 +995,9 @@ class OperationLog(db.Model):
        db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    account_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    action = db.Column(db.String(255), nullable=False)
    content = db.Column(db.JSON)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -1013,9 +1013,9 @@ class EndUser(UserMixin, db.Model):
        db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=True)
    type = db.Column(db.String(255), nullable=False)
    external_user_id = db.Column(db.String(255), nullable=True)
    name = db.Column(db.String(255))
@ -1033,8 +1033,8 @@ class Site(db.Model):
        db.Index('site_code_idx', 'code', 'status')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    title = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.String(255))
    icon_background = db.Column(db.String(255))
@ -1074,9 +1074,9 @@ class ApiToken(db.Model):
        db.Index('api_token_tenant_idx', 'tenant_id', 'type')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=True)
    tenant_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=True)
    tenant_id = db.Column(StringUUID, nullable=True)
    type = db.Column(db.String(16), nullable=False)
    token = db.Column(db.String(255), nullable=False)
    last_used_at = db.Column(db.DateTime, nullable=True)
@ -1099,8 +1099,8 @@ class UploadFile(db.Model):
        db.Index('upload_file_tenant_idx', 'tenant_id')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    storage_type = db.Column(db.String(255), nullable=False)
    key = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
@ -1108,10 +1108,10 @@ class UploadFile(db.Model):
    extension = db.Column(db.String(255), nullable=False)
    mime_type = db.Column(db.String(255), nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    used_by = db.Column(UUID, nullable=True)
    used_by = db.Column(StringUUID, nullable=True)
    used_at = db.Column(db.DateTime, nullable=True)
    hash = db.Column(db.String(255), nullable=True)

@ -1123,9 +1123,9 @@ class ApiRequest(db.Model):
        db.Index('api_request_token_idx', 'tenant_id', 'api_token_id')
    )

    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    api_token_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    api_token_id = db.Column(StringUUID, nullable=False)
    path = db.Column(db.String(255), nullable=False)
    request = db.Column(db.Text, nullable=True)
    response = db.Column(db.Text, nullable=True)
@ -1140,8 +1140,8 @@ class MessageChain(db.Model):
        db.Index('message_chain_message_id_idx', 'message_id')
    )

    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    input = db.Column(db.Text, nullable=True)
    output = db.Column(db.Text, nullable=True)
@ -1156,9 +1156,9 @@ class MessageAgentThought(db.Model):
        db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'),
    )

    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    message_chain_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(StringUUID, nullable=False)
    message_chain_id = db.Column(StringUUID, nullable=True)
    position = db.Column(db.Integer, nullable=False)
    thought = db.Column(db.Text, nullable=True)
    tool = db.Column(db.Text, nullable=True)
@ -1166,7 +1166,7 @@ class MessageAgentThought(db.Model):
    tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
    tool_input = db.Column(db.Text, nullable=True)
    observation = db.Column(db.Text, nullable=True)
    # plugin_id = db.Column(UUID, nullable=True) ## for future design
    # plugin_id = db.Column(StringUUID, nullable=True) ## for future design
    tool_process_data = db.Column(db.Text, nullable=True)
    message = db.Column(db.Text, nullable=True)
    message_token = db.Column(db.Integer, nullable=True)
@ -1182,7 +1182,7 @@ class MessageAgentThought(db.Model):
    currency = db.Column(db.String, nullable=True)
    latency = db.Column(db.Float, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
@ -1273,15 +1273,15 @@ class DatasetRetrieverResource(db.Model):
        db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
    )

    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    dataset_id = db.Column(UUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    dataset_name = db.Column(db.Text, nullable=False)
    document_id = db.Column(UUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    document_name = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.Text, nullable=False)
    segment_id = db.Column(UUID, nullable=False)
    segment_id = db.Column(StringUUID, nullable=False)
    score = db.Column(db.Float, nullable=True)
    content = db.Column(db.Text, nullable=False)
    hit_count = db.Column(db.Integer, nullable=True)
@ -1289,7 +1289,7 @@ class DatasetRetrieverResource(db.Model):
    segment_position = db.Column(db.Integer, nullable=True)
    index_node_hash = db.Column(db.Text, nullable=True)
    retriever_from = db.Column(db.Text, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


@ -1303,11 +1303,11 @@ class Tag(db.Model):

    TAG_TYPE_LIST = ['knowledge', 'app']

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=True)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=True)
    type = db.Column(db.String(16), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


@ -1319,9 +1319,9 @@ class TagBinding(db.Model):
        db.Index('tag_bind_tag_id_idx', 'tag_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=True)
    tag_id = db.Column(UUID, nullable=True)
    target_id = db.Column(UUID, nullable=True)
    created_by = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=True)
    tag_id = db.Column(StringUUID, nullable=True)
    target_id = db.Column(StringUUID, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -1,8 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
@ -46,8 +45,8 @@ class Provider(db.Model):
|
||||
db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
@ -93,8 +92,8 @@ class ProviderModel(db.Model):
|
||||
db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
@ -111,8 +110,8 @@ class TenantDefaultModel(db.Model):
|
||||
db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'),
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
model_name = db.Column(db.String(40), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
@ -127,8 +126,8 @@ class TenantPreferredModelProvider(db.Model):
|
||||
db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'),
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
preferred_provider_type = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
@ -142,10 +141,10 @@ class ProviderOrder(db.Model):
|
||||
db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'),
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(UUID, nullable=False)
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
account_id = db.Column(UUID, nullable=False)
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
payment_product_id = db.Column(db.String(191), nullable=False)
|
||||
payment_id = db.Column(db.String(191))
|
||||
transaction_id = db.Column(db.String(191))
|
||||
|
||||
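The hunks above (and those that follow) replace the Postgres-only `UUID` column type with a project-level `StringUUID` imported from `models`. The diff never shows that class itself; as a minimal sketch of what such a cross-dialect type could look like (assumptions: it follows SQLAlchemy's standard `TypeDecorator` pattern and lives in `models/__init__.py`; the real implementation may differ):

```python
# Hypothetical sketch of a cross-dialect UUID type; not the verbatim Dify class.
import uuid

from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID


class StringUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Native uuid type on PostgreSQL, a plain CHAR(36) elsewhere.
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_bind_param(self, value, dialect):
        # Accept uuid.UUID or str; store as a string.
        return str(value) if value is not None else None

    def process_result_value(self, value, dialect):
        # Always hand the application a plain string.
        return str(value) if value is not None else None
```

A decorator like this keeps the models portable across database dialects while the application code only ever sees string ids.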
@@ -1,6 +1,7 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.dialects.postgresql import JSONB

from extensions.ext_database import db
from models import StringUUID


class DataSourceBinding(db.Model):

@@ -11,8 +12,8 @@ class DataSourceBinding(db.Model):
        db.Index('source_info_idx', "source_info", postgresql_using='gin')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    access_token = db.Column(db.String(255), nullable=False)
    provider = db.Column(db.String(255), nullable=False)
    source_info = db.Column(JSONB, nullable=False)
@@ -1,9 +1,8 @@
import json
from enum import Enum

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db
from models import StringUUID


class ToolProviderName(Enum):

@@ -24,8 +23,8 @@ class ToolProvider(db.Model):
        db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    tool_name = db.Column(db.String(40), nullable=False)
    encrypted_credentials = db.Column(db.Text, nullable=True)
    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -1,12 +1,12 @@
import json

from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import UUID

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType
from extensions.ext_database import db
from models import StringUUID
from models.model import Account, App, Tenant

@@ -22,11 +22,11 @@ class BuiltinToolProvider(db.Model):
    )

    # id of the tool provider
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # id of the tenant
    tenant_id = db.Column(UUID, nullable=True)
    tenant_id = db.Column(StringUUID, nullable=True)
    # who created this tool provider
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # name of the tool provider
    provider = db.Column(db.String(40), nullable=False)
    # credential of the tool provider

@@ -49,11 +49,11 @@ class PublishedAppTool(db.Model):
    )

    # id of the tool provider
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # id of the app
    app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False)
    app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False)
    # who published this tool
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # description of the tool, stored in i18n format, for human
    description = db.Column(db.Text, nullable=False)
    # llm_description of the tool, for LLM

@@ -87,7 +87,7 @@ class ApiToolProvider(db.Model):
        db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # name of the api provider
    name = db.Column(db.String(40), nullable=False)
    # icon

@@ -96,9 +96,9 @@ class ApiToolProvider(db.Model):
    schema = db.Column(db.Text, nullable=False)
    schema_type_str = db.Column(db.String(40), nullable=False)
    # who created this tool
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # tenant id
    tenant_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    # description of the provider
    description = db.Column(db.Text, nullable=False)
    # json format tools

@@ -140,11 +140,11 @@ class ToolModelInvoke(db.Model):
        db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # who invoke this tool
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # tenant id
    tenant_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    # provider
    provider = db.Column(db.String(40), nullable=False)
    # type

@@ -180,13 +180,13 @@ class ToolConversationVariables(db.Model):
        db.Index('conversation_id_idx', 'conversation_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # conversation user id
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # tenant id
    tenant_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    # conversation id
    conversation_id = db.Column(UUID, nullable=False)
    conversation_id = db.Column(StringUUID, nullable=False)
    # variables pool
    variables_str = db.Column(db.Text, nullable=False)

@@ -208,13 +208,13 @@ class ToolFile(db.Model):
        db.Index('tool_file_conversation_id_idx', 'conversation_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # conversation user id
    user_id = db.Column(UUID, nullable=False)
    user_id = db.Column(StringUUID, nullable=False)
    # tenant id
    tenant_id = db.Column(UUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    # conversation id
    conversation_id = db.Column(UUID, nullable=True)
    conversation_id = db.Column(StringUUID, nullable=True)
    # file key
    file_key = db.Column(db.String(255), nullable=False)
    # mime type
@@ -1,6 +1,6 @@
from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db
from models import StringUUID
from models.model import Message

@@ -11,11 +11,11 @@ class SavedMessage(db.Model):
        db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    message_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    message_id = db.Column(StringUUID, nullable=False)
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property

@@ -30,9 +30,9 @@ class PinnedConversation(db.Model):
        db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(UUID, nullable=False)
    conversation_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    conversation_id = db.Column(StringUUID, nullable=False)
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -2,10 +2,9 @@ import json
from enum import Enum
from typing import Optional, Union

from sqlalchemy.dialects.postgresql import UUID

from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models import StringUUID
from models.account import Account

@@ -102,16 +101,16 @@ class Workflow(db.Model):
        db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    version = db.Column(db.String(255), nullable=False)
    graph = db.Column(db.Text)
    features = db.Column(db.Text)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(UUID)
    updated_by = db.Column(StringUUID)
    updated_at = db.Column(db.DateTime)

    @property

@@ -245,11 +244,11 @@ class WorkflowRun(db.Model):
        db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=False)
    sequence_number = db.Column(db.Integer, nullable=False)
    workflow_id = db.Column(UUID, nullable=False)
    workflow_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    triggered_from = db.Column(db.String(255), nullable=False)
    version = db.Column(db.String(255), nullable=False)

@@ -262,7 +261,7 @@ class WorkflowRun(db.Model):
    total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
    total_steps = db.Column(db.Integer, server_default=db.text('0'))
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    finished_at = db.Column(db.DateTime)

@@ -404,12 +403,12 @@ class WorkflowNodeExecution(db.Model):
                 'triggered_from', 'node_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=False)
    workflow_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=False)
    workflow_id = db.Column(StringUUID, nullable=False)
    triggered_from = db.Column(db.String(255), nullable=False)
    workflow_run_id = db.Column(UUID)
    workflow_run_id = db.Column(StringUUID)
    index = db.Column(db.Integer, nullable=False)
    predecessor_node_id = db.Column(db.String(255))
    node_id = db.Column(db.String(255), nullable=False)

@@ -424,7 +423,7 @@ class WorkflowNodeExecution(db.Model):
    execution_metadata = db.Column(db.Text)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    finished_at = db.Column(db.DateTime)

    @property

@@ -529,14 +528,14 @@ class WorkflowAppLog(db.Model):
        db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'),
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    app_id = db.Column(UUID, nullable=False)
    workflow_id = db.Column(UUID, nullable=False)
    workflow_run_id = db.Column(UUID, nullable=False)
    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    app_id = db.Column(StringUUID, nullable=False)
    workflow_id = db.Column(StringUUID, nullable=False)
    workflow_run_id = db.Column(StringUUID, nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
@@ -1,7 +1,7 @@
beautifulsoup4==4.12.2
flask~=3.0.1
Flask-SQLAlchemy~=3.0.5
SQLAlchemy~=1.4.28
SQLAlchemy~=2.0.29
Flask-Compress~=1.14
flask-login~=0.6.3
flask-migrate~=4.0.5

@@ -44,6 +44,7 @@ google-auth-httplib2==0.2.0
google-generativeai==0.5.0
google-search-results==2.4.2
googleapis-common-protos==1.63.0
google-cloud-storage==2.16.0
replicate~=0.22.0
websocket-client~=1.7.0
dashscope[tokenizer]~=1.17.0

@@ -80,4 +81,5 @@ lxml==5.1.0
xlrd~=2.0.1
pydantic~=1.10.0
pgvecto-rs==0.1.4
oss2==2.15.0
firecrawl-py==0.0.5
oss2==2.15.0
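The notable pin here is `SQLAlchemy~=1.4.28` moving to `~=2.0.29`. SQLAlchemy 2.0 still supports the legacy `db.session.query(...)` style used throughout the models above, but the 2.0-style `select()` form is the documented direction. A minimal sketch of the equivalence (assuming the `Document` model and Flask-SQLAlchemy session from the surrounding code; this is illustrative, not a change the diff itself makes):

```python
from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import Document


def get_documents(dataset_id: str) -> list[Document]:
    # 2.0-style query; equivalent to
    # db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
    return db.session.scalars(
        select(Document).where(Document.dataset_id == dataset_id)
    ).all()
```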
@@ -74,5 +74,5 @@ class BillingService:
            TenantAccountJoin.account_id == current_user.id
        ).first()

        if TenantAccountRole.is_privileged_role(join.role):
        if not TenantAccountRole.is_privileged_role(join.role):
            raise ValueError('Only team owner or team admin can perform this action')
@@ -418,9 +418,8 @@ class DocumentService:
    def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.indexing_status == 'error' or Document.indexing_status == 'paused'
            Document.indexing_status.in_(['error', 'paused'])
        ).all()

        return documents

    @staticmethod
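The fix above is worth spelling out: Python's `or` operates on the two comparison *expression objects* before SQLAlchemy ever builds SQL from them, so the removed line never sends both conditions to the database (depending on the SQLAlchemy version this either misbehaves silently or raises at query time). Both correct spellings push the whole disjunction into the SQL WHERE clause; a sketch using the same model:

```python
from sqlalchemy import or_

from extensions.ext_database import db
from models.dataset import Document


def get_error_documents(dataset_id: str) -> list[Document]:
    return db.session.query(Document).filter(
        Document.dataset_id == dataset_id,
        # The fix used in the diff:
        Document.indexing_status.in_(['error', 'paused']),
        # Equivalently: or_(Document.indexing_status == 'error',
        #                   Document.indexing_status == 'paused')
    ).all()
```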
@@ -1,38 +1,36 @@
import uuid

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


def test_milvus_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = MilvusVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=MilvusConfig(
            host='localhost',
            port=19530,
            user='root',
            password='Milvus',
class MilvusVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = MilvusVector(
            collection_name=self.collection_name,
            config=MilvusConfig(
                host='localhost',
                port=19530,
                user='root',
                password='Milvus',
            )
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )
    def search_by_full_text(self):
        # milvus does not support full-text search yet in < 2.3.x
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1
    def delete_by_document_id(self):
        self.vector.delete_by_document_id(document_id=self.example_doc_id)

    # milvus does not support full-text search yet in < 2.3.x
    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
        assert len(ids) == 1

    # delete vector
    vector.delete()

def test_milvus_vector(setup_mock_redis):
    MilvusVectorTest().run_all_tests()
@@ -0,0 +1,37 @@
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class TestPgvectoRSVector(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = PGVectoRS(
            collection_name=self.collection_name.lower(),
            config=PgvectoRSConfig(
                host='localhost',
                port=5431,
                user='postgres',
                password='difyai123456',
                database='dify',
            ),
            dim=128
        )

    def search_by_full_text(self):
        # pgvecto-rs only supports English full-text search, so it is disabled for now
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def delete_by_document_id(self):
        self.vector.delete_by_document_id(document_id=self.example_doc_id)

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
        assert len(ids) == 1


def test_pgvecot_rs(setup_mock_redis):
    TestPgvectoRSVector().run_all_tests()
@@ -1,40 +1,23 @@
import uuid

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    AbstractVectorTest,
    setup_mock_redis,
)


def test_qdrant_vector(setup_mock_redis)-> None:
    dataset_id = str(uuid.uuid4())
    vector = QdrantVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        group_id=dataset_id,
        config=QdrantConfig(
            endpoint='http://localhost:6333',
            api_key='difyai123456',
class QdrantVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
        self.vector = QdrantVector(
            collection_name=self.collection_name,
            group_id=self.dataset_id,
            config=QdrantConfig(
                endpoint='http://localhost:6333',
                api_key='difyai123456',
            )
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()

def test_qdrant_vector(setup_mock_redis):
    QdrantVectorTest().run_all_tests()
@@ -1,31 +1,26 @@
import random
import uuid
from unittest.mock import MagicMock

import pytest

from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset


def get_sample_text() -> str:
def get_example_text() -> str:
    return 'test_text'


def get_sample_embedding() -> list[float]:
    return [1.1, 2.2, 3.3]


def get_sample_query_vector() -> list[float]:
    return get_sample_embedding()


def get_sample_document(sample_dataset_id: str) -> Document:
def get_example_document(doc_id: str) -> Document:
    doc = Document(
        page_content=get_sample_text(),
        page_content=get_example_text(),
        metadata={
            "doc_id": sample_dataset_id,
            "doc_hash": sample_dataset_id,
            "document_id": sample_dataset_id,
            "dataset_id": sample_dataset_id,
            "doc_id": doc_id,
            "doc_hash": doc_id,
            "document_id": doc_id,
            "dataset_id": doc_id,
        }
    )
    return doc

@@ -44,3 +39,63 @@ def setup_mock_redis() -> None:
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock


class AbstractVectorTest:
    def __init__(self):
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
        self.example_doc_id = str(uuid.uuid4())
        self.example_embedding = [1.001 * i for i in range(128)]

    def create_vector(self) -> None:
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )

    def search_by_vector(self):
        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1
        assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 1
        assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id

    def delete_vector(self):
        self.vector.delete()

    def delete_by_ids(self, ids: list[str]):
        self.vector.delete_by_ids(ids=ids)

    def add_texts(self) -> list[str]:
        batch_size = 100
        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
        embeddings = [self.example_embedding] * batch_size
        self.vector.add_texts(documents=documents, embeddings=embeddings)
        return [doc.metadata['doc_id'] for doc in documents]

    def text_exists(self):
        assert self.vector.text_exists(self.example_doc_id)

    def delete_by_document_id(self):
        with pytest.raises(NotImplementedError):
            self.vector.delete_by_document_id(document_id=self.example_doc_id)

    def get_ids_by_metadata_field(self):
        with pytest.raises(NotImplementedError):
            self.vector.get_ids_by_metadata_field(key='key', value='value')

    def run_all_tests(self):
        self.create_vector()
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        self.delete_by_document_id()
        added_doc_ids = self.add_texts()
        self.delete_by_ids(added_doc_ids)
        self.delete_vector()
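With this shared harness, wiring in another store is mostly a matter of subclassing: construct the client in `__init__` and override only the hooks whose defaults (raising `NotImplementedError`, or the exact hit counts) do not match the backend, exactly as the Milvus and pgvecto-rs classes above do. A hypothetical sketch for some new backend (`FooVector` is a stand-in, not a real Dify class):

```python
class FooVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        # FooVector is hypothetical; a real subclass would build the actual client here.
        self.vector = FooVector(collection_name=self.collection_name)

    def search_by_full_text(self):
        # Override the default when the backend has no full-text search: expect zero hits.
        hits = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits) == 0


def test_foo_vector(setup_mock_redis):
    FooVectorTest().run_all_tests()
```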
@@ -1,41 +1,23 @@
import uuid

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    AbstractVectorTest,
    setup_mock_redis,
)


def test_weaviate_vector(setup_mock_redis) -> None:
    attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
    dataset_id = str(uuid.uuid4())
    vector = WeaviateVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=WeaviateConfig(
            endpoint='http://localhost:8080',
            api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
        ),
        attributes=attributes
    )
class WeaviateVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
        self.vector = WeaviateVector(
            collection_name=self.collection_name,
            config=WeaviateConfig(
                endpoint='http://localhost:8080',
                api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
            ),
            attributes=self.attributes
        )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()

def test_weaviate_vector(setup_mock_redis):
    WeaviateVectorTest().run_all_tests()
@@ -114,7 +114,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
    result = node.run(pool)

    assert result.status == WorkflowNodeExecutionStatus.FAILED
    assert result.error == 'result in output form must be a string'
    assert result.error == 'Output variable `result` must be a string'


def test_execute_code_output_validator_depth():
    code = '''
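The assertion pins the code node's new error wording. A minimal sketch of the kind of check behind it (a hypothetical helper for illustration; the real validator lives in the code node implementation and is not shown in this diff):

```python
def check_string_output(outputs: dict, key: str = 'result') -> None:
    # Hypothetical helper: raises with the exact message the test above asserts.
    if not isinstance(outputs.get(key), str):
        raise ValueError(f'Output variable `{key}` must be a string')
```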
@@ -53,7 +53,7 @@ services:

  # The DifySandbox
  sandbox:
    image: langgenius/dify-sandbox:latest
    image: langgenius/dify-sandbox:0.1.0
    restart: always
    cap_add:
      # Why is sys_admin permission needed?

@@ -80,3 +80,4 @@ services:
  #     QDRANT_API_KEY: 'difyai123456'
  #   ports:
  #     - "6333:6333"
  #     - "6334:6334"
docker/docker-compose.pgvecto-rs.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
version: '3'
services:
  # The pgvecto-rs database.
  pgvecto-rs:
    image: tensorchord/pgvecto-rs:pg16-v0.2.0
    restart: always
    environment:
      PGUSER: postgres
      # The password for the default postgres user.
      POSTGRES_PASSWORD: difyai123456
      # The name of the default postgres database.
      POSTGRES_DB: dify
      # postgres data directory
      PGDATA: /var/lib/postgresql/data/pgdata
    volumes:
      - ./volumes/pgvectors/data:/var/lib/postgresql/data
    # uncomment to expose db(postgresql) port to host
    ports:
      - "5431:5432"
    healthcheck:
      test: [ "CMD", "pg_isready" ]
      interval: 1s
      timeout: 3s
      retries: 30

@@ -10,3 +10,4 @@ services:
      QDRANT_API_KEY: 'difyai123456'
    ports:
      - "6333:6333"
      - "6334:6334"
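For a quick check that the pgvecto-rs container above is reachable with the credentials it declares, a minimal sketch (assumes `psycopg2` is installed; host port 5431 is mapped to the container's 5432 as configured):

```python
import psycopg2

# Connection values taken from the compose file above.
conn = psycopg2.connect(
    host='localhost',
    port=5431,
    user='postgres',
    password='difyai123456',
    dbname='dify',
)
with conn.cursor() as cur:
    cur.execute('SELECT version()')
    print(cur.fetchone()[0])
conn.close()
```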
@@ -2,7 +2,7 @@ version: '3'
services:
  # API service
  api:
    image: langgenius/dify-api:0.6.5
    image: langgenius/dify-api:0.6.6
    restart: always
    environment:
      # Startup mode, 'api' starts the API server.

@@ -70,7 +70,7 @@ services:
      # If you want to enable cross-origin support,
      # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
      #
      # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob`, Default: `local`
      # The type of storage to use for storing user files. Supported values are `local` and `s3` and `azure-blob` and `google-storage`, Default: `local`
      STORAGE_TYPE: local
      # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.
      # only available when STORAGE_TYPE is `local`.

@@ -86,7 +86,10 @@ services:
      AZURE_BLOB_ACCOUNT_KEY: 'difyai'
      AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
      AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
      # The Google storage configurations, only available when STORAGE_TYPE is `google-storage`.
      GOOGLE_STORAGE_BUCKET_NAME: 'your-bucket-name'
      GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: 'your-google-service-account-json-base64-string'
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
      VECTOR_STORE: weaviate
      # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
      WEAVIATE_ENDPOINT: http://weaviate:8080

@@ -96,8 +99,12 @@ services:
      QDRANT_URL: http://qdrant:6333
      # The Qdrant API key.
      QDRANT_API_KEY: difyai123456
      # The Qdrant clinet timeout setting.
      # The Qdrant client timeout setting.
      QDRANT_CLIENT_TIMEOUT: 20
      # The Qdrant client enable gRPC mode.
      QDRANT_GRPC_ENABLED: 'false'
      # The Qdrant server gRPC mode PORT.
      QDRANT_GRPC_PORT: 6334
      # Milvus configuration. Only available when VECTOR_STORE is `milvus`.
      # The milvus host.
      MILVUS_HOST: 127.0.0.1

@@ -109,6 +116,12 @@ services:
      MILVUS_PASSWORD: Milvus
      # The milvus tls switch.
      MILVUS_SECURE: 'false'
      # relyt configurations
      RELYT_HOST: db
      RELYT_PORT: 5432
      RELYT_USER: postgres
      RELYT_PASSWORD: difyai123456
      RELYT_DATABASE: postgres
      # Mail configuration, support: resend, smtp
      MAIL_TYPE: ''
      # default send from email address, if not specified

@@ -127,6 +140,11 @@ services:
      SENTRY_TRACES_SAMPLE_RATE: 1.0
      # The sample rate for Sentry profiles. Default: `1.0`
      SENTRY_PROFILES_SAMPLE_RATE: 1.0
      # Notion import configuration, support public and internal
      NOTION_INTEGRATION_TYPE: public
      NOTION_CLIENT_SECRET: you-client-secret
      NOTION_CLIENT_ID: you-client-id
      NOTION_INTERNAL_SECRET: you-internal-secret
      # The sandbox service endpoint.
      CODE_EXECUTION_ENDPOINT: "http://sandbox:8194"
      CODE_EXECUTION_API_KEY: dify-sandbox

@@ -150,7 +168,7 @@ services:
  # worker service
  # The Celery worker for processing the queue.
  worker:
    image: langgenius/dify-api:0.6.5
    image: langgenius/dify-api:0.6.6
    restart: always
    environment:
      # Startup mode, 'worker' starts the Celery worker for processing the queue.

@@ -193,7 +211,7 @@ services:
      AZURE_BLOB_ACCOUNT_KEY: 'difyai'
      AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
      AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
      VECTOR_STORE: weaviate
      # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
      WEAVIATE_ENDPOINT: http://weaviate:8080

@@ -205,6 +223,10 @@ services:
      QDRANT_API_KEY: difyai123456
      # The Qdrant client timeout setting.
      QDRANT_CLIENT_TIMEOUT: 20
      # The Qdrant client enable gRPC mode.
      QDRANT_GRPC_ENABLED: 'false'
      # The Qdrant server gRPC mode PORT.
      QDRANT_GRPC_PORT: 6334
      # Milvus configuration. Only available when VECTOR_STORE is `milvus`.
      # The milvus host.
      MILVUS_HOST: 127.0.0.1

@@ -229,6 +251,11 @@ services:
      RELYT_USER: postgres
      RELYT_PASSWORD: difyai123456
      RELYT_DATABASE: postgres
      # Notion import configuration, support public and internal
      NOTION_INTEGRATION_TYPE: public
      NOTION_CLIENT_SECRET: you-client-secret
      NOTION_CLIENT_ID: you-client-id
      NOTION_INTERNAL_SECRET: you-internal-secret
    depends_on:
      - db
      - redis

@@ -238,10 +265,9 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:0.6.5
    image: langgenius/dify-web:0.6.6
    restart: always
    environment:
      EDITION: SELF_HOSTED
      # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
      # different from api or web app domain.
      # example: http://cloud.dify.ai

@@ -320,7 +346,7 @@ services:

  # The DifySandbox
  sandbox:
    image: langgenius/dify-sandbox:latest
    image: langgenius/dify-sandbox:0.1.0
    restart: always
    cap_add:
      # Why is sys_admin permission needed?

@@ -346,6 +372,7 @@ services:
  # # uncomment to expose qdrant port to host
  # # ports:
  # #  - "6333:6333"
  # #  - "6334:6334"

  # The nginx reverse proxy.
  # used for reverse proxying the API service and Web service.
@@ -1,6 +1,6 @@
# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED or CLOUD
# The deployment edition, SELF_HOSTED
NEXT_PUBLIC_EDITION=SELF_HOSTED
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.

@@ -17,7 +17,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t
```
# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED or CLOUD
# The deployment edition, SELF_HOSTED
NEXT_PUBLIC_EDITION=SELF_HOSTED
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
@@ -1,11 +1,12 @@
'use client'

import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import AppCard from './AppCard'
import NewAppCard from './NewAppCard'
import useAppsQueryState from './hooks/useAppsQueryState'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context'

@@ -54,10 +55,15 @@ const Apps = () => {
  const [activeTab, setActiveTab] = useTabSearchParams({
    defaultTab: 'all',
  })
  const [tagFilterValue, setTagFilterValue] = useState<string[]>([])
  const [tagIDs, setTagIDs] = useState<string[]>([])
  const [keywords, setKeywords] = useState('')
  const [searchKeywords, setSearchKeywords] = useState('')
  const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
  const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
  const [searchKeywords, setSearchKeywords] = useState(keywords)
  const setKeywords = useCallback((keywords: string) => {
    setQuery(prev => ({ ...prev, keywords }))
  }, [setQuery])
  const setTagIDs = useCallback((tagIDs: string[]) => {
    setQuery(prev => ({ ...prev, tagIDs }))
  }, [setQuery])

  const { data, isLoading, setSize, mutate } = useSWRInfinite(
    (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),

@@ -81,17 +87,18 @@ const Apps = () => {
    }
  }, [])

  const hasMore = data?.at(-1)?.has_more ?? true
  useEffect(() => {
    let observer: IntersectionObserver | undefined
    if (anchorRef.current) {
      observer = new IntersectionObserver((entries) => {
        if (entries[0].isIntersecting && !isLoading)
        if (entries[0].isIntersecting && !isLoading && hasMore)
          setSize((size: number) => size + 1)
      }, { rootMargin: '100px' })
      observer.observe(anchorRef.current)
    }
    return () => observer?.disconnect()
  }, [isLoading, setSize, anchorRef, mutate])
  }, [isLoading, setSize, anchorRef, mutate, hasMore])

  const { run: handleSearch } = useDebounceFn(() => {
    setSearchKeywords(keywords)
web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts (new file, 53 lines)
@@ -0,0 +1,53 @@
import { type ReadonlyURLSearchParams, usePathname, useRouter, useSearchParams } from 'next/navigation'
import { useCallback, useEffect, useMemo, useState } from 'react'

type AppsQuery = {
  tagIDs?: string[]
  keywords?: string
}

// Parse the query parameters from the URL search string.
function parseParams(params: ReadonlyURLSearchParams): AppsQuery {
  const tagIDs = params.get('tagIDs')?.split(';')
  const keywords = params.get('keywords') || undefined
  return { tagIDs, keywords }
}

// Update the URL search string with the given query parameters.
function updateSearchParams(query: AppsQuery, current: URLSearchParams) {
  const { tagIDs, keywords } = query || {}

  if (tagIDs && tagIDs.length > 0)
    current.set('tagIDs', tagIDs.join(';'))
  else
    current.delete('tagIDs')

  if (keywords)
    current.set('keywords', keywords)
  else
    current.delete('keywords')
}

function useAppsQueryState() {
  const searchParams = useSearchParams()
  const [query, setQuery] = useState<AppsQuery>(() => parseParams(searchParams))

  const router = useRouter()
  const pathname = usePathname()
  const syncSearchParams = useCallback((params: URLSearchParams) => {
    const search = params.toString()
    const query = search ? `?${search}` : ''
    router.push(`${pathname}${query}`)
  }, [router, pathname])

  // Update the URL search string whenever the query changes.
  useEffect(() => {
    const params = new URLSearchParams(searchParams)
    updateSearchParams(query, params)
    syncSearchParams(params)
  }, [query, searchParams, syncSearchParams])

  return useMemo(() => ({ query, setQuery }), [query])
}

export default useAppsQueryState
@@ -1,248 +1,248 @@
'use client'
import type { FC, SVGProps } from 'react'
import React, { useEffect } from 'react'
import { usePathname } from 'next/navigation'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import classNames from 'classnames'
import { useBoolean } from 'ahooks'
import {
  Cog8ToothIcon,
  // CommandLineIcon,
  Squares2X2Icon,
  // eslint-disable-next-line sort-imports
  PuzzlePieceIcon,
  DocumentTextIcon,
  PaperClipIcon,
  QuestionMarkCircleIcon,
} from '@heroicons/react/24/outline'
import {
  Cog8ToothIcon as Cog8ToothSolidIcon,
  // CommandLineIcon as CommandLineSolidIcon,
  DocumentTextIcon as DocumentTextSolidIcon,
} from '@heroicons/react/24/solid'
import Link from 'next/link'
import s from './style.module.css'
import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets'
import type { RelatedApp, RelatedAppResponse } from '@/models/datasets'
import { getLocaleOnClient } from '@/i18n'
import AppSideBar from '@/app/components/app-sidebar'
import Divider from '@/app/components/base/divider'
import AppIcon from '@/app/components/base/app-icon'
import Loading from '@/app/components/base/loading'
import FloatPopoverContainer from '@/app/components/base/float-popover-container'
import DatasetDetailContext from '@/context/dataset-detail'
import { DataSourceType } from '@/models/datasets'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { LanguagesSupported } from '@/i18n/language'
import { useStore } from '@/app/components/app/store'
import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication'
import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel'

export type IAppDetailLayoutProps = {
  children: React.ReactNode
  params: { datasetId: string }
}

type ILikedItemProps = {
  type?: 'plugin' | 'app'
  appStatus?: boolean
  detail: RelatedApp
  isMobile: boolean
}

const LikedItem = ({
  type = 'app',
  detail,
  isMobile,
}: ILikedItemProps) => {
  return (
    <Link className={classNames(s.itemWrapper, 'px-2', isMobile && 'justify-center')} href={`/app/${detail?.id}/overview`}>
      <div className={classNames(s.iconWrapper, 'mr-0')}>
        <AppIcon size='tiny' icon={detail?.icon} background={detail?.icon_background} />
        {type === 'app' && (
          <span className='absolute bottom-[-2px] right-[-2px] w-3.5 h-3.5 p-0.5 bg-white rounded border-[0.5px] border-[rgba(0,0,0,0.02)] shadow-sm'>
            {detail.mode === 'advanced-chat' && (
              <ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
            )}
            {detail.mode === 'agent-chat' && (
              <CuteRobote className='w-2.5 h-2.5 text-indigo-600' />
            )}
            {detail.mode === 'chat' && (
              <ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
            )}
            {detail.mode === 'completion' && (
              <AiText className='w-2.5 h-2.5 text-[#0E9384]' />
            )}
            {detail.mode === 'workflow' && (
              <Route className='w-2.5 h-2.5 text-[#f79009]' />
            )}
          </span>
        )}
      </div>
      {!isMobile && <div className={classNames(s.appInfo, 'ml-2')}>{detail?.name || '--'}</div>}
    </Link>
  )
}

const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <g clipPath="url(#clip0_4610_6951)">
      <path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
    </g>
    <defs>
      <clipPath id="clip0_4610_6951">
        <rect width="16" height="16" fill="white" />
      </clipPath>
    </defs>
  </svg>
}

const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
    <path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
    <path d="M5.33317 7.99984C5.33317 6.52708 6.52708 5.33317 7.99984 5.33317C8.36803 5.33317 8.6665 5.03469 8.6665 4.6665C8.6665 4.29831 8.36803 3.99984 7.99984 3.99984C5.7907 3.99984 3.99984 5.7907 3.99984 7.99984C3.99984 10.209 5.7907 11.9998 7.99984 11.9998C10.209 11.9998 11.9998 10.209 11.9998 7.99984C11.9998 7.63165 11.7014 7.33317 11.3332 7.33317C10.965 7.33317 10.6665 7.63165 10.6665 7.99984C10.6665 9.4726 9.4726 10.6665 7.99984 10.6665C6.52708 10.6665 5.33317 9.4726 5.33317 7.99984Z" fill="#155EEF" />
  </svg>
}

const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
    <path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />
  </svg>
}

type IExtraInfoProps = {
  isMobile: boolean
  relatedApps?: RelatedAppResponse
}

const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => {
  const locale = getLocaleOnClient()
  const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile)
  const { t } = useTranslation()

  useEffect(() => {
    setShowTips(!isMobile)
  }, [isMobile, setShowTips])

  return <div className='w-full flex flex-col items-center'>
    <Divider className='mt-5' />
    {(relatedApps?.data && relatedApps?.data?.length > 0) && (
      <>
        {!isMobile && <div className='w-full px-2 pb-1 pt-4 uppercase text-xs text-gray-500 font-medium'>{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}</div>}
        {isMobile && <div className={classNames(s.subTitle, 'flex items-center justify-center !px-0 gap-1')}>
          {relatedApps?.total || '--'}
          <PaperClipIcon className='h-4 w-4 text-gray-700' />
        </div>}
        {relatedApps?.data?.map((item, index) => (<LikedItem key={index} isMobile={isMobile} detail={item} />))}
      </>
    )}
    {!relatedApps?.data?.length && (
      <FloatPopoverContainer
        placement='bottom-start'
        open={isShowTips}
        toggle={toggleTips}
        isMobile={isMobile}
        triggerElement={
          <div className={classNames('h-7 w-7 inline-flex justify-center items-center rounded-lg bg-transparent', isShowTips && '!bg-gray-50')}>
            <QuestionMarkCircleIcon className='h-4 w-4 flex-shrink-0 text-gray-500' />
          </div>
        }
      >
        <div className={classNames('mt-5 p-3', isMobile && 'border-[0.5px] border-gray-200 shadow-lg rounded-lg bg-white w-[160px]')}>
          <div className='flex items-center justify-start gap-2'>
            <div className={s.emptyIconDiv}>
              <Squares2X2Icon className='w-3 h-3 text-gray-500' />
            </div>
            <div className={s.emptyIconDiv}>
              <PuzzlePieceIcon className='w-3 h-3 text-gray-500' />
            </div>
          </div>
          <div className='text-xs text-gray-500 mt-2'>{t('common.datasetMenus.emptyTip')}</div>
          <a
            className='inline-flex items-center text-xs text-primary-600 mt-2 cursor-pointer'
            href={
              locale === LanguagesSupported[1]
                ? 'https://docs.dify.ai/v/zh-hans/guides/application-design/prompt-engineering'
                : 'https://docs.dify.ai/user-guide/creating-dify-apps/prompt-engineering'
            }
            target='_blank' rel='noopener noreferrer'
          >
            <BookOpenIcon className='mr-1' />
            {t('common.datasetMenus.viewDoc')}
          </a>
        </div>
      </FloatPopoverContainer>
    )}
  </div>
}

const DatasetDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
  const {
    children,
    params: { datasetId },
  } = props
  const pathname = usePathname()
  const hideSideBar = /documents\/create$/.test(pathname)
  const { t } = useTranslation()

  const media = useBreakpoints()
  const isMobile = media === MediaType.mobile

  const { data: datasetRes, error, mutate: mutateDatasetRes } = useSWR({
    url: 'fetchDatasetDetail',
    datasetId,
  }, apiParams => fetchDatasetDetail(apiParams.datasetId))

  const { data: relatedApps } = useSWR({
    action: 'fetchDatasetRelatedApps',
    datasetId,
  }, apiParams => fetchDatasetRelatedApps(apiParams.datasetId))

  const navigation = [
    { name: t('common.datasetMenus.documents'), href: `/datasets/${datasetId}/documents`, icon: DocumentTextIcon, selectedIcon: DocumentTextSolidIcon },
    { name: t('common.datasetMenus.hitTesting'), href: `/datasets/${datasetId}/hitTesting`, icon: TargetIcon, selectedIcon: TargetSolidIcon },
    // { name: 'api & webhook', href: `/datasets/${datasetId}/api`, icon: CommandLineIcon, selectedIcon: CommandLineSolidIcon },
    { name: t('common.datasetMenus.settings'), href: `/datasets/${datasetId}/settings`, icon: Cog8ToothIcon, selectedIcon: Cog8ToothSolidIcon },
  ]

  useEffect(() => {
    if (datasetRes)
      document.title = `${datasetRes.name || 'Dataset'} - Dify`
  }, [datasetRes])

  const setAppSiderbarExpand = useStore(state => state.setAppSiderbarExpand)

  useEffect(() => {
    const localeMode = localStorage.getItem('app-detail-collapse-or-expand') || 'expand'
    const mode = isMobile ? 'collapse' : 'expand'
    setAppSiderbarExpand(isMobile ? mode : localeMode)
  }, [isMobile, setAppSiderbarExpand])

  if (!datasetRes && !error)
    return <Loading />

  return (
    <div className='grow flex overflow-hidden'>
      {!hideSideBar && <AppSideBar
        title={datasetRes?.name || '--'}
        icon={datasetRes?.icon || 'https://static.dify.ai/images/dataset-default-icon.png'}
        icon_background={datasetRes?.icon_background || '#F5F5F5'}
        desc={datasetRes?.description || '--'}
        navigation={navigation}
        extraInfo={mode => <ExtraInfo isMobile={mode === 'collapse'} relatedApps={relatedApps} />}
        iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'}
      />}
      <DatasetDetailContext.Provider value={{
        indexingTechnique: datasetRes?.indexing_technique,
        dataset: datasetRes,
        mutateDatasetRes: () => mutateDatasetRes(),
      }}>
        <div className="bg-white grow overflow-hidden">{children}</div>
      </DatasetDetailContext.Provider>
    </div>
  )
}
export default React.memo(DatasetDetailLayout)
'use client'
|
||||
import type { FC, SVGProps } from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { usePathname } from 'next/navigation'
|
||||
import useSWR from 'swr'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import classNames from 'classnames'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import {
|
||||
Cog8ToothIcon,
|
||||
// CommandLineIcon,
|
||||
Squares2X2Icon,
|
||||
// eslint-disable-next-line sort-imports
|
||||
PuzzlePieceIcon,
|
||||
DocumentTextIcon,
|
||||
PaperClipIcon,
|
||||
QuestionMarkCircleIcon,
|
||||
} from '@heroicons/react/24/outline'
|
||||
import {
|
||||
Cog8ToothIcon as Cog8ToothSolidIcon,
|
||||
// CommandLineIcon as CommandLineSolidIcon,
|
||||
DocumentTextIcon as DocumentTextSolidIcon,
|
||||
} from '@heroicons/react/24/solid'
|
||||
import Link from 'next/link'
|
||||
import s from './style.module.css'
|
||||
import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets'
|
||||
import type { RelatedApp, RelatedAppResponse } from '@/models/datasets'
|
||||
import AppSideBar from '@/app/components/app-sidebar'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import FloatPopoverContainer from '@/app/components/base/float-popover-container'
|
||||
import DatasetDetailContext from '@/context/dataset-detail'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { LanguagesSupported } from '@/i18n/language'
|
||||
import { useStore } from '@/app/components/app/store'
|
||||
import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication'
|
||||
import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel'
|
||||
import { getLocaleOnClient } from '@/i18n'
|
||||
|
||||
export type IAppDetailLayoutProps = {
  children: React.ReactNode
  params: { datasetId: string }
}

type ILikedItemProps = {
  type?: 'plugin' | 'app'
  appStatus?: boolean
  detail: RelatedApp
  isMobile: boolean
}

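// Renders one related app as a sidebar entry; the small corner badge maps the
// app mode (advanced-chat, agent-chat, chat, completion, workflow) to an icon.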
const LikedItem = ({
  type = 'app',
  detail,
  isMobile,
}: ILikedItemProps) => {
  return (
    <Link className={classNames(s.itemWrapper, 'px-2', isMobile && 'justify-center')} href={`/app/${detail?.id}/overview`}>
      <div className={classNames(s.iconWrapper, 'mr-0')}>
        <AppIcon size='tiny' icon={detail?.icon} background={detail?.icon_background} />
        {type === 'app' && (
          <span className='absolute bottom-[-2px] right-[-2px] w-3.5 h-3.5 p-0.5 bg-white rounded border-[0.5px] border-[rgba(0,0,0,0.02)] shadow-sm'>
            {detail.mode === 'advanced-chat' && (
              <ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
            )}
            {detail.mode === 'agent-chat' && (
              <CuteRobote className='w-2.5 h-2.5 text-indigo-600' />
            )}
            {detail.mode === 'chat' && (
              <ChatBot className='w-2.5 h-2.5 text-[#1570EF]' />
            )}
            {detail.mode === 'completion' && (
              <AiText className='w-2.5 h-2.5 text-[#0E9384]' />
            )}
            {detail.mode === 'workflow' && (
              <Route className='w-2.5 h-2.5 text-[#f79009]' />
            )}
          </span>
        )}
      </div>
      {!isMobile && <div className={classNames(s.appInfo, 'ml-2')}>{detail?.name || '--'}</div>}
    </Link>
  )
}

const TargetIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <g clipPath="url(#clip0_4610_6951)">
      <path d="M10.6666 5.33325V3.33325L12.6666 1.33325L13.3332 2.66659L14.6666 3.33325L12.6666 5.33325H10.6666ZM10.6666 5.33325L7.9999 7.99988M14.6666 7.99992C14.6666 11.6818 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6818 1.33325 7.99992C1.33325 4.31802 4.31802 1.33325 7.99992 1.33325M11.3333 7.99992C11.3333 9.84087 9.84087 11.3333 7.99992 11.3333C6.15897 11.3333 4.66659 9.84087 4.66659 7.99992C4.66659 6.15897 6.15897 4.66659 7.99992 4.66659" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
    </g>
    <defs>
      <clipPath id="clip0_4610_6951">
        <rect width="16" height="16" fill="white" />
      </clipPath>
    </defs>
  </svg>
}

const TargetSolidIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <path fillRule="evenodd" clipRule="evenodd" d="M12.7733 0.67512C12.9848 0.709447 13.1669 0.843364 13.2627 1.03504L13.83 2.16961L14.9646 2.73689C15.1563 2.83273 15.2902 3.01486 15.3245 3.22639C15.3588 3.43792 15.2894 3.65305 15.1379 3.80458L13.1379 5.80458C13.0128 5.92961 12.8433 5.99985 12.6665 5.99985H10.9426L8.47124 8.47124C8.21089 8.73159 7.78878 8.73159 7.52843 8.47124C7.26808 8.21089 7.26808 7.78878 7.52843 7.52843L9.9998 5.05707V3.33318C9.9998 3.15637 10.07 2.9868 10.1951 2.86177L12.1951 0.861774C12.3466 0.710244 12.5617 0.640794 12.7733 0.67512Z" fill="#155EEF" />
    <path d="M1.99984 7.99984C1.99984 4.68613 4.68613 1.99984 7.99984 1.99984C8.36803 1.99984 8.6665 1.70136 8.6665 1.33317C8.6665 0.964981 8.36803 0.666504 7.99984 0.666504C3.94975 0.666504 0.666504 3.94975 0.666504 7.99984C0.666504 12.0499 3.94975 15.3332 7.99984 15.3332C12.0499 15.3332 15.3332 12.0499 15.3332 7.99984C15.3332 7.63165 15.0347 7.33317 14.6665 7.33317C14.2983 7.33317 13.9998 7.63165 13.9998 7.99984C13.9998 11.3135 11.3135 13.9998 7.99984 13.9998C4.68613 13.9998 1.99984 11.3135 1.99984 7.99984Z" fill="#155EEF" />
    <path d="M5.33317 7.99984C5.33317 6.52708 6.52708 5.33317 7.99984 5.33317C8.36803 5.33317 8.6665 5.03469 8.6665 4.6665C8.6665 4.29831 8.36803 3.99984 7.99984 3.99984C5.7907 3.99984 3.99984 5.7907 3.99984 7.99984C3.99984 10.209 5.7907 11.9998 7.99984 11.9998C10.209 11.9998 11.9998 10.209 11.9998 7.99984C11.9998 7.63165 11.7014 7.33317 11.3332 7.33317C10.965 7.33317 10.6665 7.63165 10.6665 7.99984C10.6665 9.4726 9.4726 10.6665 7.99984 10.6665C6.52708 10.6665 5.33317 9.4726 5.33317 7.99984Z" fill="#155EEF" />
  </svg>
}

const BookOpenIcon = ({ className }: SVGProps<SVGElement>) => {
  return <svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
    <path opacity="0.12" d="M1 3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7V10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1Z" fill="#155EEF" />
    <path d="M6 10.5L5.94997 10.425C5.60265 9.90398 5.42899 9.64349 5.19955 9.45491C4.99643 9.28796 4.76238 9.1627 4.5108 9.0863C4.22663 9 3.91355 9 3.28741 9H2.6C2.03995 9 1.75992 9 1.54601 8.89101C1.35785 8.79513 1.20487 8.64215 1.10899 8.45399C1 8.24008 1 7.96005 1 7.4V3.1C1 2.53995 1 2.25992 1.10899 2.04601C1.20487 1.85785 1.35785 1.70487 1.54601 1.60899C1.75992 1.5 2.03995 1.5 2.6 1.5H2.8C3.9201 1.5 4.48016 1.5 4.90798 1.71799C5.28431 1.90973 5.59027 2.21569 5.78201 2.59202C6 3.01984 6 3.5799 6 4.7M6 10.5V4.7M6 10.5L6.05003 10.425C6.39735 9.90398 6.57101 9.64349 6.80045 9.45491C7.00357 9.28796 7.23762 9.1627 7.4892 9.0863C7.77337 9 8.08645 9 8.71259 9H9.4C9.96005 9 10.2401 9 10.454 8.89101C10.6422 8.79513 10.7951 8.64215 10.891 8.45399C11 8.24008 11 7.96005 11 7.4V3.1C11 2.53995 11 2.25992 10.891 2.04601C10.7951 1.85785 10.6422 1.70487 10.454 1.60899C10.2401 1.5 9.96005 1.5 9.4 1.5H9.2C8.07989 1.5 7.51984 1.5 7.09202 1.71799C6.71569 1.90973 6.40973 2.21569 6.21799 2.59202C6 3.01984 6 3.5799 6 4.7" stroke="#155EEF" strokeLinecap="round" strokeLinejoin="round" />
  </svg>
}

type IExtraInfoProps = {
  isMobile: boolean
  relatedApps?: RelatedAppResponse
}

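// Sidebar footer: lists the apps linked to this dataset, or, when none exist
// yet, shows a help popover that links out to the prompt-engineering docs.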
const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => {
  const locale = getLocaleOnClient()
  const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile)
  const { t } = useTranslation()

  useEffect(() => {
    setShowTips(!isMobile)
  }, [isMobile, setShowTips])

  return <div className='w-full flex flex-col items-center'>
    <Divider className='mt-5' />
    {(relatedApps?.data && relatedApps?.data?.length > 0) && (
      <>
        {!isMobile && <div className='w-full px-2 pb-1 pt-4 uppercase text-xs text-gray-500 font-medium'>{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}</div>}
        {isMobile && <div className={classNames(s.subTitle, 'flex items-center justify-center !px-0 gap-1')}>
          {relatedApps?.total || '--'}
          <PaperClipIcon className='h-4 w-4 text-gray-700' />
        </div>}
        {relatedApps?.data?.map((item, index) => (<LikedItem key={index} isMobile={isMobile} detail={item} />))}
      </>
    )}
    {!relatedApps?.data?.length && (
      <FloatPopoverContainer
        placement='bottom-start'
        open={isShowTips}
        toggle={toggleTips}
        isMobile={isMobile}
        triggerElement={
          <div className={classNames('h-7 w-7 inline-flex justify-center items-center rounded-lg bg-transparent', isShowTips && '!bg-gray-50')}>
            <QuestionMarkCircleIcon className='h-4 w-4 flex-shrink-0 text-gray-500' />
          </div>
        }
      >
        <div className={classNames('mt-5 p-3', isMobile && 'border-[0.5px] border-gray-200 shadow-lg rounded-lg bg-white w-[160px]')}>
          <div className='flex items-center justify-start gap-2'>
            <div className={s.emptyIconDiv}>
              <Squares2X2Icon className='w-3 h-3 text-gray-500' />
            </div>
            <div className={s.emptyIconDiv}>
              <PuzzlePieceIcon className='w-3 h-3 text-gray-500' />
            </div>
          </div>
          <div className='text-xs text-gray-500 mt-2'>{t('common.datasetMenus.emptyTip')}</div>
          <a
            className='inline-flex items-center text-xs text-primary-600 mt-2 cursor-pointer'
            href={
              locale === LanguagesSupported[1]
                ? 'https://docs.dify.ai/v/zh-hans/guides/application-design/prompt-engineering'
                : 'https://docs.dify.ai/user-guide/creating-dify-apps/prompt-engineering'
            }
            target='_blank' rel='noopener noreferrer'
          >
            <BookOpenIcon className='mr-1' />
            {t('common.datasetMenus.viewDoc')}
          </a>
        </div>
      </FloatPopoverContainer>
    )}
  </div>
}

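// Layout for /datasets/[datasetId]/*: fetches the dataset and its related apps
// via SWR, builds the sidebar navigation, and shares the dataset with nested
// pages through DatasetDetailContext.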
const DatasetDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
  const {
    children,
    params: { datasetId },
  } = props
  const pathname = usePathname()
  const hideSideBar = /documents\/create$/.test(pathname)
  const { t } = useTranslation()

  const media = useBreakpoints()
  const isMobile = media === MediaType.mobile

  const { data: datasetRes, error, mutate: mutateDatasetRes } = useSWR({
    url: 'fetchDatasetDetail',
    datasetId,
  }, apiParams => fetchDatasetDetail(apiParams.datasetId))

  const { data: relatedApps } = useSWR({
    action: 'fetchDatasetRelatedApps',
    datasetId,
  }, apiParams => fetchDatasetRelatedApps(apiParams.datasetId))

  const navigation = [
    { name: t('common.datasetMenus.documents'), href: `/datasets/${datasetId}/documents`, icon: DocumentTextIcon, selectedIcon: DocumentTextSolidIcon },
    { name: t('common.datasetMenus.hitTesting'), href: `/datasets/${datasetId}/hitTesting`, icon: TargetIcon, selectedIcon: TargetSolidIcon },
    // { name: 'api & webhook', href: `/datasets/${datasetId}/api`, icon: CommandLineIcon, selectedIcon: CommandLineSolidIcon },
    { name: t('common.datasetMenus.settings'), href: `/datasets/${datasetId}/settings`, icon: Cog8ToothIcon, selectedIcon: Cog8ToothSolidIcon },
  ]

  useEffect(() => {
    if (datasetRes)
      document.title = `${datasetRes.name || 'Dataset'} - Dify`
  }, [datasetRes])

  const setAppSiderbarExpand = useStore(state => state.setAppSiderbarExpand)

  useEffect(() => {
    const localeMode = localStorage.getItem('app-detail-collapse-or-expand') || 'expand'
    const mode = isMobile ? 'collapse' : 'expand'
    setAppSiderbarExpand(isMobile ? mode : localeMode)
  }, [isMobile, setAppSiderbarExpand])

  if (!datasetRes && !error)
    return <Loading />

  return (
    <div className='grow flex overflow-hidden'>
      {!hideSideBar && <AppSideBar
        title={datasetRes?.name || '--'}
        icon={datasetRes?.icon || 'https://static.dify.ai/images/dataset-default-icon.png'}
        icon_background={datasetRes?.icon_background || '#F5F5F5'}
        desc={datasetRes?.description || '--'}
        navigation={navigation}
        extraInfo={mode => <ExtraInfo isMobile={mode === 'collapse'} relatedApps={relatedApps} />}
        iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'}
      />}
      <DatasetDetailContext.Provider value={{
        indexingTechnique: datasetRes?.indexing_technique,
        dataset: datasetRes,
        mutateDatasetRes: () => mutateDatasetRes(),
      }}>
        <div className="bg-white grow overflow-hidden">{children}</div>
      </DatasetDetailContext.Provider>
    </div>
  )
}
export default React.memo(DatasetDetailLayout)

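For orientation, here is a minimal sketch of how a page nested under this layout could read the values passed to DatasetDetailContext.Provider above. It assumes DatasetDetailContext behaves like a plain React context; the DocumentsPage name and markup are illustrative, not the repository's actual page.

'use client'
// Hypothetical consumer of the context provided by DatasetDetailLayout above.
// Assumes DatasetDetailContext is a standard React context; DocumentsPage and
// its markup do not appear in the file above and are illustrative only.
import { useContext } from 'react'
import DatasetDetailContext from '@/context/dataset-detail'

const DocumentsPage = () => {
  const { dataset, indexingTechnique, mutateDatasetRes } = useContext(DatasetDetailContext)
  return (
    <div>
      <h1>{dataset?.name}</h1>
      <div>Indexing technique: {indexingTechnique ?? 'unknown'}</div>
      <button onClick={() => mutateDatasetRes?.()}>Reload dataset</button>
    </div>
  )
}

export default DocumentsPage
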
@@ -86,7 +86,7 @@ const ActivateForm = () => {
         timezone,
       },
     })
-    setLocaleOnClient(language.startsWith('en') ? 'en' : 'zh-Hans', false)
+    setLocaleOnClient(language.startsWith('en') ? 'en-US' : 'zh-Hans', false)
     setShowSuccess(true)
   }
   catch {

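The hunk above swaps the bare 'en' tag for the region-qualified 'en-US' used by the locale identifiers elsewhere in the i18n setup. A one-line helper in the same spirit (the name is illustrative, not part of the diff):

// Illustrative only: normalize a browser language tag onto the two locales the
// activate form supports, mirroring the ternary in the hunk above.
const normalizeLocale = (language: string): 'en-US' | 'zh-Hans' =>
  language.startsWith('en') ? 'en-US' : 'zh-Hans'
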
@@ -362,7 +362,7 @@ const Answer: FC<IAnswerProps> = ({
         {!item.isOpeningStatement && (
           <CopyBtn
             value={content}
-            className={cn(s.copyBtn, 'mr-1')}
+            className='mr-1'
           />
         )}
         {((isShowPromptLog && !isResponding) || (!item.isOpeningStatement && isShowTextToSpeech)) && (

@@ -8,9 +8,8 @@ import { useParams } from 'next/navigation'
import { HandThumbDownIcon, HandThumbUpIcon } from '@heroicons/react/24/outline'
import { useBoolean } from 'ahooks'
import { HashtagIcon } from '@heroicons/react/24/solid'
// import PromptLog from '@/app/components/app/chat/log'
import ResultTab from './result-tab'
import { Markdown } from '@/app/components/base/markdown'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import Loading from '@/app/components/base/loading'
import Toast from '@/app/components/base/toast'
import AudioBtn from '@/app/components/base/audio-btn'

@@ -26,7 +25,6 @@ import EditReplyModal from '@/app/components/app/annotation/edit-annotation-moda
import { useStore as useAppStore } from '@/app/components/app/store'
import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'

const MAX_DEPTH = 3

@@ -293,23 +291,17 @@ const GenerationItem: FC<IGenerationItemProps> = ({
       <div className={`flex ${contentClassName}`}>
         <div className='grow w-0'>
           {workflowProcessData && (
-            <WorkflowProcessItem grayBg data={workflowProcessData} expand={workflowProcessData.expand} />
+            <WorkflowProcessItem grayBg hideInfo data={workflowProcessData} expand={workflowProcessData.expand} />
           )}
+          {workflowProcessData && !isError && (
+            <ResultTab data={workflowProcessData} content={content} />
+          )}
           {isError && (
             <div className='text-gray-400 text-sm'>{t('share.generation.batchFailed.outputPlaceholder')}</div>
           )}
-          {!isError && (typeof content === 'string') && (
+          {!workflowProcessData && !isError && (typeof content === 'string') && (
             <Markdown content={content} />
           )}
           {!isError && (typeof content !== 'string') && (
             <CodeEditor
               readOnly
               title={<div/>}
               language={CodeLanguage.json}
               value={content}
               isJSONStringifyBeauty
             />
           )}
         </div>
       </div>

@@ -427,7 +419,11 @@ const GenerationItem: FC<IGenerationItemProps> = ({
             </>
           )}
         </div>
-        <div className='text-xs text-gray-500'>{content?.length} {t('common.unit.char')}</div>
+        <div>
+          {!workflowProcessData && (
+            <div className='text-xs text-gray-500'>{content?.length} {t('common.unit.char')}</div>
+          )}
+        </div>
       </div>

     </div>

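Taken together, the two GenerationItem hunks give workflow runs their own result view: the process item hides its info row, output goes to the new ResultTab instead of the Markdown renderer, and the character counter is suppressed. A condensed restatement of the branching (a simplification for reading the diff, not the component's literal code, which renders JSX blocks rather than picking a single string):

// Simplified sketch of the output branching after the change; names are
// descriptive labels only.
type GenerationOutput = { hasWorkflow: boolean; isError: boolean; content: unknown }

const pickRenderer = ({ hasWorkflow, isError, content }: GenerationOutput): string => {
  if (isError)
    return 'error placeholder'
  if (hasWorkflow)
    return 'ResultTab'                 // new: workflow runs get the tabbed view
  if (typeof content === 'string')
    return 'Markdown'                  // unchanged path for plain text output
  return 'CodeEditor (read-only JSON)' // unchanged path for structured output
}
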
Some files were not shown because too many files have changed in this diff.