Compare commits

..

24 Commits

Author SHA1 Message Date
66d67185a0 fix: array number validation in conversation variable 2025-08-19 14:53:30 +08:00
5f0b52c017 Fix number input in tool configure form of agent node tool item (#24154) 2025-08-19 14:26:09 +08:00
c2606f9062 fix: correct behaviour of code fix (#24152)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-19 14:18:49 +08:00
70da81d0e5 try ast-grep (#24149) 2025-08-19 13:41:52 +08:00
75199442c1 feat: Implements periodic deletion of workflow run logs that exceed t… (#23881)
Co-authored-by: shiyun.li973792 <shiyun.li@seres.cn>
Co-authored-by: 1wangshu <suewangswu@gmail.com>
Co-authored-by: Blackoutta <hyytez@gmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-08-19 09:47:34 +08:00
60cc82aff1 feat: add testcontainers based tests for feature service (#24026) 2025-08-19 09:32:47 +08:00
ebd2c8236d an example of suppress (#24136) 2025-08-19 00:21:26 +08:00
a2537ba4fd chore: translate i18n files (#24131)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-08-18 23:46:23 +08:00
a3a041ef6f feat: add delete avatar functionality with confirmation modal (#24127)
Co-authored-by: crazywoola <427733928@qq.com>
2025-08-18 21:35:20 +08:00
c0702aacac Use typing.Literal to replace str places (#24099)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-18 21:34:13 +08:00
670d479e32 Bump pyobvector to 0.2.15 (#24120) 2025-08-18 17:36:27 +08:00
ae7de7d36b fix: treat default template of code as empty (#24106)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-18 16:52:27 +08:00
ef5decc98a Chore: remove some dead code in experience-enhance-group (#24110)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
2025-08-18 16:51:43 +08:00
4445460eca fix: validate checklist before publishing workflow (#24104) 2025-08-18 16:46:22 +08:00
8288b1dcab Revert "fix pg_vector extension requires SUPERUSER, but not availabl… (#24108) 2025-08-18 16:46:15 +08:00
16d1289a0a fix pg_vector extension requires SUPERUSER, but not available on Huawei Cloud RDS (#24093)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-18 16:29:36 +08:00
ba775a1c90 chore: translate i18n files (#24102)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-08-18 16:18:45 +08:00
b0e58f9da7 Feature/improve goto anything commands (#24091) 2025-08-18 16:07:54 +08:00
531e784a92 feat: no longer enable auto upgrade when marketplace is disabled (#24… (#24101) 2025-08-18 15:57:33 +08:00
5e8fe30035 fix(ui): Optimize UI component styles and layouts (#24090) (#24092) 2025-08-18 15:56:10 +08:00
f5033c5a0e fix: no current code caused code generation show error (#24086) 2025-08-18 14:18:08 +08:00
26d7654851 chore: translate i18n files (#24081)
Co-authored-by: Stream29 <36751053+Stream29@users.noreply.github.com>
2025-08-18 12:45:17 +08:00
790a6ec203 fix: return empty list instead of raising exception for qdrant search when score_threshold is 1 (#24032) 2025-08-18 12:44:05 +08:00
de9c5f10b3 feat: enchance prompt and code (#23633)
Co-authored-by: stream <stream@dify.ai>
Co-authored-by: Stream <1542763342@qq.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-08-18 12:29:12 +08:00
1236 changed files with 12481 additions and 61971 deletions

View File

@ -23,6 +23,9 @@ jobs:
uv run ruff check --fix-only .
# Format code
uv run ruff format .
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -8,7 +8,6 @@ on:
- "deploy/enterprise"
- "build/**"
- "release/e-*"
- "deploy/rag-dev"
tags:
- "*"

View File

@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/dev"
types:
- completed
@ -12,13 +12,12 @@ jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.conclusion == 'success'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

2
.gitignore vendored
View File

@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info
!.vscode/README.md
pyrightconfig.json
api/.vscode
# vscode Code History Extension
.history
.idea/

View File

@ -478,6 +478,13 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200

View File

@ -1,3 +1,4 @@
import os
import sys
@ -16,20 +17,20 @@ else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app

View File

@ -13,14 +13,11 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.impl.plugin import PluginInstaller
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@ -33,9 +30,7 @@ from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
@ -1343,228 +1338,3 @@ def cleanup_orphaned_draft_variables(
continue
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_datasource_oauth_client(provider, client_params):
"""
Setup datasource oauth client
"""
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
oauth_client = DatasourceOauthParamConfig(
provider=provider_name,
plugin_id=plugin_id,
system_credentials=client_params_dict,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"provider: {provider_name}", fg="green"))
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
"""
Transform datasource credentials
"""
try:
installer_manager = PluginInstaller()
plugin_migration = PluginMigration()
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for credential in notion_credentials:
tenant_id = credential.tenant_id
if tenant_id not in notion_credentials_tenant_mapping:
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
installer_manager.install_from_identifiers(
tenant_id,
[notion_plugin_unique_identifier],
PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": notion_plugin_unique_identifier,
}
],
)
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential oauth params
access_token = credential.access_token
# notion info
notion_info = credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in firecrawl_credentials:
tenant_id = credential.tenant_id
if tenant_id not in firecrawl_credentials_tenant_mapping:
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
installer_manager.install_from_identifiers(
tenant_id,
[firecrawl_plugin_unique_identifier],
PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": firecrawl_plugin_unique_identifier,
}
],
)
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential api key
api_key = credential.credentials.get("config", {}).get("api_key")
base_url = credential.credentials.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jina").all()
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in jina_credentials:
tenant_id = credential.tenant_id
if tenant_id not in jina_credentials_tenant_mapping:
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
installer_manager.install_from_identifiers(
tenant_id,
[jina_plugin_unique_identifier],
PluginInstallationSource.Marketplace,
metas=[
{
"plugin_unique_identifier": jina_plugin_unique_identifier,
}
],
)
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential api key
api_key = credential.credentials.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
click.echo(
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
)
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))

View File

@ -968,6 +968,14 @@ class AccountConfig(BaseSettings):
)
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1003,5 +1011,6 @@ class FeatureConfig(
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
WorkflowLogConfig,
):
pass

View File

@ -222,28 +222,11 @@ class HostedFetchAppTemplateConfig(BaseSettings):
)
class HostedFetchPipelineTemplateConfig(BaseSettings):
"""
Configuration for fetching pipeline templates
"""
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="database",
)
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote pipeline templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
# place the configs in alphabet order
HostedAnthropicConfig,
HostedAzureOpenAiConfig,
HostedFetchAppTemplateConfig,
HostedFetchPipelineTemplateConfig,
HostedMinmaxConfig,
HostedOpenAiConfig,
HostedSparkConfig,

View File

@ -3,7 +3,6 @@ from threading import Lock
from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)

View File

@ -1,3 +1,4 @@
import contextlib
import mimetypes
import os
import platform
@ -65,10 +66,8 @@ def guess_file_info_from_response(response: httpx.Response):
# Use python-magic to guess MIME type if still unknown or generic
if mimetype == "application/octet-stream" and magic is not None:
try:
with contextlib.suppress(magic.MagicException):
mimetype = magic.from_buffer(response.content[:1024], mime=True)
except magic.MagicException:
pass
extension = os.path.splitext(filename)[1]

View File

@ -87,15 +87,6 @@ from .datasets import (
upload_file,
website,
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_draft_variable,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action):
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200

View File

@ -1,3 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -10,6 +12,8 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
@ -107,6 +111,121 @@ class RuleStructuredOutputGenerateApi(Resource):
return structured_output
class InstructionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
)
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
from models import App, db
from services.workflow_service import WorkflowService
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0:
return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
return {"error": "incompatible parameters"}, 400
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
class InstructionGenerationTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self) -> dict:
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT}
case "code":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args['type']}")
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
api.add_resource(InstructionGenerateApi, "/instruction-generate")
api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")

View File

@ -1,6 +1,4 @@
import json
from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
@ -11,8 +9,6 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
@ -22,7 +18,6 @@ from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -117,18 +112,6 @@ class DataSourceNotionListApi(Resource):
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
@ -152,49 +135,31 @@ class DataSourceNotionListApi(Resource):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
if credential:
datasource_runtime.runtime.credentials = credential
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
data_source_bindings = session.scalars(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
)
try:
pages = []
workspace_info = {}
for message in online_document_result:
result = message.result
for info in result:
workspace_info = {
"workspace_id": info.workspace_id,
"workspace_name": info.workspace_name,
"workspace_icon": info.workspace_icon,
}
for page in info.pages:
page_info = {
"page_id": page.page_id,
"page_name": page.page_name,
"type": page.type,
"parent_id": page.parent_id,
"is_bound": page.page_id in exist_page_ids,
"page_icon": page.page_icon,
}
pages.append(page_info)
except Exception as e:
raise e
return {"notion_info": {**workspace_info, "pages": pages}}, 200
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
@ -202,25 +167,27 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding:
raise NotFound("Data source binding not found.")
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id,
)
@ -245,12 +212,10 @@ class DataSourceNotionApi(Resource):
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],

View File

@ -279,15 +279,6 @@ class DatasetApi(Resource):
location="json",
help="Invalid external knowledge api id.",
)
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
args = parser.parse_args()
data = request.get_json()
@ -438,12 +429,10 @@ class DatasetIndexingEstimateApi(Resource):
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],

View File

@ -1,7 +1,6 @@
import json
import logging
from argparse import ArgumentTypeError
from typing import cast
from typing import Literal, cast
from flask import request
from flask_login import current_user
@ -52,7 +51,6 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -510,7 +508,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@ -664,7 +661,7 @@ class DocumentApi(DocumentResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -761,7 +758,7 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@ -787,8 +784,6 @@ class DocumentProcessingApi(DocumentResource):
document.paused_at = None
document.is_paused = False
db.session.commit()
else:
raise InvalidActionError()
return {"result": "success"}, 200
@ -843,7 +838,7 @@ class DocumentStatusApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@ -1029,41 +1024,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200
class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
)
if not log:
return {
"datasource_info": None,
"datasource_type": None,
"input_data": None,
"datasource_node_id": None,
}, 200
return {
"datasource_info": json.loads(log.datasource_info),
"datasource_type": log.datasource_type,
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
@ -1085,6 +1045,3 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@ -71,9 +71,3 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500
class PipelineNotFoundError(BaseHTTPException):
error_code = "pipeline_not_found"
description = "Pipeline not found."
code = 404

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action):
def post(self, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -1,365 +0,0 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
oauth_config = DatasourceProviderService().get_oauth_client(
tenant_id=tenant_id,
datasource_provider_id=datasource_provider_id,
)
if not oauth_config:
raise ValueError(f"No OAuth Client Config for {provider_id}")
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
credential_id=credential_id,
)
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_config,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class DatasourceOAuthCallback(Resource):
@setup_required
def get(self, provider_id: str):
context_id = request.cookies.get("context_id") or request.args.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
datasource_provider_service = DatasourceProviderService()
oauth_client_params = datasource_provider_service.get_oauth_client(
tenant_id=tenant_id,
datasource_provider_id=datasource_provider_id,
)
if not oauth_client_params:
raise NotFound()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
oauth_handler = OAuthHandler()
oauth_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=datasource_provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credential_id = context.get("credential_id")
if credential_id:
datasource_provider_service.reauthorize_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=datasource_provider_id,
avatar_url=oauth_response.metadata.get("avatar_url") or None,
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
credential_id=context.get("credential_id"),
)
else:
datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=datasource_provider_id,
avatar_url=oauth_response.metadata.get("avatar_url") or None,
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}),
name=args.get("name", None),
)
return {"result": "success"}, 201
class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200
class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200
class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
return {"result": "success"}, 200
class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],
)
return {"result": "success"}, 200
api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
DatasourceOAuthCallback,
"/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource/<path:provider_id>",
)
api.add_resource(
DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
api.add_resource(
DatasourceHardCodeAuthListApi,
"/auth/plugin/datasource/default-list",
)
api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)
api.add_resource(
DatasourceAuthDefaultApi,
"/auth/plugin/datasource/<path:provider_id>/default",
)
api.add_resource(
DatasourceUpdateProviderNameApi,
"/auth/plugin/datasource/<path:provider_id>/update-name",
)

View File

@ -1,57 +0,0 @@
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DataSourceContentPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run datasource content preview
"""
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=datasource_type,
is_published=True,
credential_id=args.get("credential_id"),
)
return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@ -1,164 +0,0 @@
import logging
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
knowledge_pipeline_publish_enabled,
setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
return {"data": template.yaml_content}, 200
class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@ -1,171 +0,0 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
parser.add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
args = parser.parse_args()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@ -1,416 +0,0 @@
import logging
from typing import Any, NoReturn
from flask import Response
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import db
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda _: "env"),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
}
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
- Dify has been property setup.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
if not workflow_exist:
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id,
page=args.page,
limit=args.limit,
)
return workflow_vars
@_api_prerequisite
def delete(self, pipeline: Pipeline):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(pipeline.id)
db.session.commit()
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
# with specific `node_id` in database, we still want to make the API separated. By disallowing
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk that user of the API depending on the implementation detail of the API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
return node_vars
@_api_prerequisite
def delete(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(pipeline.id, node_id)
db.session.commit()
return Response("", 204)
class RagPipelineVariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
# Local File:
#
# {
# "type": "image",
# "transfer_method": "local_file",
# "url": "",
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
# }
#
# Remote File:
#
#
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
rag_pipeline_service = RagPipelineService()
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
return draft_vars
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
class RagPipelineEnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, pipeline: Pipeline):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}
api.add_resource(
RagPipelineVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
RagPipelineNodeVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
RagPipelineEnvironmentVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)

View File

@ -1,147 +0,0 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
dataset_name=args.get("name"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@ -39,7 +39,7 @@ class UploadFileApi(Resource):
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:

View File

@ -1,43 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
if not pipeline:
raise PipelineNotFoundError()
kwargs["pipeline"] = pipeline
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -261,14 +261,3 @@ def is_allow_transfer_owner(view):
abort(403)
return decorated
def knowledge_pipeline_publish_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
return decorated

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action):
def post(self, app_model: App, action: Literal["enable", "disable"]):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource):
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask import request
from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource):
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
def patch(self, tenant_id, dataset_id, action):
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
"""
Batch update document status.
Args:
tenant_id: tenant id
dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive)
action: action to perform (Literal["enable", "disable", "archive", "un_archive"])
Returns:
dict: A dictionary with a key 'result' and a value 'success'

View File

@ -1,3 +1,5 @@
from typing import Literal
from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action):
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -97,7 +97,6 @@ class VariableEntityType(StrEnum):
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
class VariableEntity(BaseModel):
@ -114,9 +113,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -129,16 +128,6 @@ class VariableEntity(BaseModel):
return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
@ -298,7 +287,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: Optional[AppAdditionalFeatures] = None
additional_features: AppAdditionalFeatures
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

View File

@ -1,6 +1,4 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import VariableEntity
from models.workflow import Workflow
@ -22,44 +20,3 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
# get datasource node data
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if datasource_node_data:
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for key, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split(".")[-1]
variables_map.pop(last_part)
all_second_step_variables = list(variables_map.values())
for item in all_second_step_variables:
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(item))
return variables

View File

@ -43,13 +43,11 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
@ -185,14 +183,6 @@ class WorkflowResponseConverter:
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
node_data = cast(DatasourceNodeData, event.node_data)
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
f"{node_data.plugin_id}/{node_data.provider_name}",
)
response.data.extras["icon"] = provider_entity.declaration.identity.icon
return response

View File

@ -1,95 +0,0 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@ -1,66 +0,0 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -1,812 +0,0 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
# run in child thread
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"pipeline": pipeline,
"workflow_id": workflow.id,
"user": user,
"application_generate_entity": application_generate_entity,
"invoke_from": invoke_from,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"streaming": streaming,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
)
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %r",
application_generate_entity.task_id,
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info["name"]
elif datasource_type == "online_document":
name = datasource_info["page"]["page_name"]
elif datasource_type == "website_crawl":
name = datasource_info["title"]
elif datasource_type == "online_drive":
name = datasource_info["key"]
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=datasource_node_data.get("credential_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("id") and datasource_info.get("type") == "folder":
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("id", ""),
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": datasource_info.get("id", ""),
"bucket": datasource_info.get("bucket", None),
}
)
return all_files
else:
return datasource_info_list
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: Optional[str],
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: Optional[dict] = None,
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
next_page_parameters=next_page_parameters,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
last_file_key = None
for result in result_generator:
for files in result.result:
for file in files.files:
if file.type == "folder":
self._get_files_in_folder(
datasource_runtime,
file.id,
bucket,
user_id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": file.id,
"bucket": bucket,
}
)
is_truncated = files.is_truncated
next_page_parameters = files.next_page_parameters
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
)

View File

@ -1,45 +0,0 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@ -1,252 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
# init graph
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
start_node_id=self.application_generate_entity.start_node_id,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
for event in generator:
self._update_document_status(
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
"""
Init pipeline graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
nodes = graph_config.get("nodes", [])
edges = graph_config.get("edges", [])
real_run_nodes = []
real_edges = []
exclude_node_ids = []
for node in nodes:
node_id = node.get("id")
node_type = node.get("data", {}).get("type", "")
if node_type == "datasource":
if start_node_id != node_id:
exclude_node_ids.append(node_id)
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
"""
Update document status
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.filter(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()

View File

@ -35,7 +35,6 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
@classmethod
def value_of(cls, value: str):
@ -241,38 +240,3 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: Optional[str] = None
start_node_id: Optional[str] = None
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class SingleLoopRunEntity(BaseModel):
"""
Single Loop Run Entity.
"""
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None

View File

@ -181,7 +181,7 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

View File

@ -105,14 +105,6 @@ class DifyAgentCallbackHandler(BaseModel):
self.current_loop += 1
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
"""Run on datasource start."""
if dify_config.DEBUG:
print_text(
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
color=self.color,
)
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""

View File

@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class DatasourcePlugin(ABC):
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
) -> None:
self.entity = entity
self.runtime = runtime
@abstractmethod
def datasource_provider_type(self) -> str:
"""
returns the type of the datasource provider
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
)

View File

@ -1,118 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_datasource_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@abstractmethod
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
"""
return datasource with given name
"""
pass
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@ -1,36 +0,0 @@
from typing import Any, Optional
from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
class DatasourceRuntime(BaseModel):
"""
Meta data of a datasource call processing
"""
tenant_id: str
datasource_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeDatasourceRuntime(DatasourceRuntime):
"""
Fake datasource runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id",
invoke_from=InvokeFrom.DEBUGGER,
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
credentials={},
runtime_parameters={},
)

View File

@ -1,247 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class DatasourceFileManager:
@staticmethod
def sign_file(datasource_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"datasources/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=present_filename,
size=len(file_binary),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(file_binary).hexdigest(),
source_url="",
created_at=datetime.now(),
)
db.session.add(upload_file)
db.session.commit()
db.session.refresh(upload_file)
return upload_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
) -> UploadFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=filename,
size=len(blob),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(blob).hexdigest(),
source_url=file_url,
created_at=datetime.now(),
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == id,
)
.first()
)
if not upload_file:
return None
blob = storage.load_once(upload_file.key)
return blob, upload_file.mime_type
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_upload_file_id(upload_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == upload_file_id,
)
.first()
)
if not upload_file:
return None, None
stream = storage.load_stream(upload_file.key)
return stream, upload_file.mime_type
# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager

View File

@ -1,108 +0,0 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
# check if context is set
try:
contexts.datasource_plugin_providers.get()
except LookupError:
contexts.datasource_plugin_providers.set({})
contexts.datasource_plugin_providers_lock.set(Lock())
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.ONLINE_DRIVE:
controller = OnlineDriveDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider_id] = controller
return controller
@classmethod
def get_datasource_runtime(
cls,
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:param datasource_name: the name of the datasource
:param tenant_id: the tenant id
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)

View File

@ -1,71 +0,0 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
class DatasourceApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[DatasourceParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("datasources", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"datasources": datasources,
"labels": self.labels,
}

View File

@ -1,23 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -1,363 +0,0 @@
import enum
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum):
"""
Enum class for datasource provider
"""
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class DatasourceParameter(PluginParameter):
"""
Overrides type
"""
class DatasourceParameterType(enum.StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: DatasourceParameterType = Field(..., description="The type of the parameter")
description: I18nObject = Field(..., description="The description of the parameter")
@classmethod
def get_simple_instance(
cls,
name: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
type=typ,
required=required,
options=option_objs,
description=I18nObject(en_US="", zh_Hans=""),
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the datasource")
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: Optional[dict] = None
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or []
class DatasourceProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class DatasourceProviderEntity(BaseModel):
"""
Datasource provider entity
"""
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
provider_type: DatasourceProviderType
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
"""
Get an instance of DatasourceInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class DatasourceLabel(BaseModel):
"""
Datasource label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
"""
Enum class for datasource invoke
"""
RAG_PIPELINE = "rag_pipeline"
class OnlineDocumentPage(BaseModel):
"""
Online document page
"""
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
"""
Online document info
"""
workspace_id: Optional[str] = Field(None, description="The workspace id")
workspace_name: Optional[str] = Field(None, description="The workspace name")
workspace_icon: Optional[str] = Field(None, description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class OnlineDocumentPagesMessage(BaseModel):
"""
Get online document pages response
"""
result: list[OnlineDocumentInfo]
class GetOnlineDocumentPageContentRequest(BaseModel):
"""
Get online document page content request
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
type: str = Field(..., description="The type of the page")
class OnlineDocumentPageContent(BaseModel):
"""
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
class GetOnlineDocumentPageContentResponse(BaseModel):
"""
Get online document page content response
"""
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
"""
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
source_url: str = Field(..., description="The url of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
class WebSiteInfo(BaseModel):
"""
Website info
"""
status: Optional[str] = Field(..., description="crawl job status")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
"""
Get website crawl response
"""
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
class DatasourceMessage(ToolInvokeMessage):
pass
#########################
# Online drive file
#########################
class OnlineDriveFile(BaseModel):
"""
Online drive file
"""
id: str = Field(..., description="The file ID")
name: str = Field(..., description="The file name")
size: int = Field(..., description="The file size")
type: str = Field(..., description="The file type: folder or file")
class OnlineDriveFileBucket(BaseModel):
"""
Online drive file bucket
"""
bucket: Optional[str] = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
"""
Get online drive file list request
"""
bucket: Optional[str] = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):
"""
Get online drive file list response
"""
result: list[OnlineDriveFileBucket] = Field(..., description="The list of file buckets")
class OnlineDriveDownloadFileRequest(BaseModel):
"""
Get online drive file
"""
id: str = Field(..., description="The id of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")

View File

@ -1,37 +0,0 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
class DatasourceProviderNotFoundError(ValueError):
pass
class DatasourceNotFoundError(ValueError):
pass
class DatasourceParameterValidationError(ValueError):
pass
class DatasourceProviderCredentialValidationError(ValueError):
pass
class DatasourceNotSupportedError(ValueError):
pass
class DatasourceInvokeError(ValueError):
pass
class DatasourceApiSchemaError(ValueError):
pass
class DatasourceEngineInvokeError(Exception):
meta: DatasourceInvokeMeta
def __init__(self, meta, **kwargs):
self.meta = meta
super().__init__(**kwargs)

View File

@ -1,28 +0,0 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class LocalFileDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE

View File

@ -1,56 +0,0 @@
from typing import Any
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return LocalFileDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,73 +0,0 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_online_document_pages(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_pages(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def get_online_document_page_content(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_page_content(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -1,48 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,73 +0,0 @@
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def online_drive_browse_files(
self,
user_id: str,
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_browse_files(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def online_drive_download_file(
self,
user_id: str,
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_download_file(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -1,48 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DRIVE
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDriveDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,265 +0,0 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()

View File

@ -1,124 +0,0 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional
from core.datasource.datasource_file_manager import DatasourceFileManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
class DatasourceFileMessageTransformer:
@classmethod
def transform_datasource_invoke_messages(
cls,
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
"""
for message in messages:
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
yield message
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
message.message, DatasourceMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, DatasourceMessage.TextMessage)
file = DatasourceFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
# get filename from meta
filename = meta.get("file_name", None)
mimetype = meta.get("mime_type")
if not mimetype:
mimetype = guess_type(filename)[0] or "application/octet-stream"
# if message is str, encode it to bytes
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
else:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@ -1,389 +0,0 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)

View File

@ -1,17 +0,0 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)

View File

@ -1,9 +0,0 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False

View File

@ -1,43 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@ -1,35 +0,0 @@
import logging
from pathlib import Path
from typing import Any
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e

View File

@ -1,53 +0,0 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
WebsiteCrawlMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_website_crawl(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -1,52 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -17,27 +17,3 @@ class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str
class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: Optional[dict] = None
name: str
indexing_status: str
error: Optional[str] = None
enabled: bool
class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]

View File

@ -720,7 +720,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -1,15 +0,0 @@
from typing import TYPE_CHECKING, Any, cast
from core.datasource import datasource_file_manager
from core.datasource.datasource_file_manager import DatasourceFileManager
if TYPE_CHECKING:
from core.datasource.datasource_file_manager import DatasourceFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
class DatasourceFileParser:
@staticmethod
def get_datasource_file_manager() -> "DatasourceFileManager":
return cast("DatasourceFileManager", datasource_file_manager["manager"])

View File

@ -20,7 +20,6 @@ class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod
def value_of(value):

View File

@ -96,11 +96,7 @@ def to_prompt_message_content(
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)

View File

@ -1,42 +0,0 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
providers: Sequence[Any], credential_type: CredentialType, fallback_context: str = "provider"
) -> str:
try:
return generate_incremental_name(
[provider.name for provider in providers],
f"{credential_type.get_name()}",
)
except Exception as e:
logger.warning("Error generating next provider name for %r: %r", fallback_context, e)
return f"{credential_type.get_name()} 1"
def generate_incremental_name(
names: Sequence[str],
default_pattern: str,
) -> str:
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for name in names:
if not name:
continue
match = re.match(pattern, name.strip())
if match:
numbers.append(int(match.group(1)))
if not numbers:
return f"{default_pattern} 1"
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"

View File

@ -365,7 +365,6 @@ class IndexingRunner:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -1,6 +1,7 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Optional, cast
import json_repair
@ -11,6 +12,8 @@ from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
class LLMGenerator:
@ -388,3 +394,181 @@ class LLMGenerator:
except Exception as e:
logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict:
app: App | None = db.session.query(App).where(App.id == flow_id).first()
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
if not last_run:
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
current=current,
error_message="",
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
@staticmethod
def instruction_modify_workflow(
tenant_id: str,
flow_id: str,
node_id: str,
current: str,
instruction: str,
model_config: dict,
ideal_output: str | None,
) -> dict:
from services.workflow_service import WorkflowService
app: App | None = db.session.query(App).where(App.id == flow_id).first()
if not app:
raise ValueError("App not found.")
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
raise ValueError("Workflow not found for the given app model.")
last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
try:
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
except Exception:
try:
node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
"type"
]
except Exception:
node_type = "llm"
if not last_run: # Node is not executed yet
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
current=current,
error_message="",
instruction=instruction,
node_type=node_type,
ideal_output=ideal_output,
)
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
if not raw_agent_log:
return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
def dict_of_event(event: AgentLogEvent) -> dict:
return {
"status": event.status,
"error": event.error,
"data": event.data,
}
return [dict_of_event(event) for event in parsed]
last_run_dict = {
"inputs": last_run.inputs_dict,
"status": last_run.status,
"error": last_run.error,
"agent_log": agent_log_of(last_run),
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=last_run.error,
instruction=instruction,
node_type=last_run.node_type,
ideal_output=ideal_output,
)
@staticmethod
def __instruction_modify_common(
tenant_id: str,
model_config: dict,
last_run: dict | None,
current: str | None,
error_message: str | None,
instruction: str,
node_type: str,
ideal_output: str | None,
) -> dict:
LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"
ERROR_MESSAGE = "{{#error_message#}}"
injected_instruction = instruction
if LAST_RUN in injected_instruction:
injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
if CURRENT in injected_instruction:
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
if ERROR_MESSAGE in injected_instruction:
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
match node_type:
case "llm", "agent":
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
case "code":
system_prompt = LLM_MODIFY_CODE_SYSTEM
case _:
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
prompt_messages = [
SystemPromptMessage(content=system_prompt),
UserPromptMessage(
content=json.dumps(
{
"current": current,
"last_run": last_run,
"instruction": injected_instruction,
"ideal_output": ideal_output,
}
)
),
]
model_parameters = {"temperature": 0.4}
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
generated_raw = cast(str, response.message.content)
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
return {**json.loads(generated_raw[first_brace : last_brace + 1])}
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@ -309,3 +309,116 @@ eg:
Here is the JSON schema:
{{schema}}
""" # noqa: E501
LLM_MODIFY_PROMPT_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.",
"properties": {
"current": {
"type": "string",
"description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed."
},
"last_run": {
"type": "object",
"description": "The output result from the large language model after receiving the prompt.",
},
"instruction": {
"type": "string",
"description": "User's instruction to edit the current prompt"
},
"ideal_output": {
"type": "string",
"description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal."
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
Your output must strictly follow the schema format, do not output any content outside of the JSON body.
""" # noqa: E501
LLM_MODIFY_CODE_SYSTEM = """
Both your input and output should be in JSON format.
! Below is the schema for input content !
{
"type": "object",
"description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.",
"properties": {
"current": {
"type": "string",
"description": "The code before modification."
},
"last_run": {
"type": "object",
"description": "The result of the code.",
},
"message": {
"type": "string",
"description": "User's instruction to edit the current code"
}
}
}
! Above is the schema for input content !
! Below is the schema for output content !
{
"type": "object",
"description": "Your feedback to the user after they provide modification suggestions.",
"properties": {
"modified": {
"type": "string",
"description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change"
},
"message": {
"type": "string",
"description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user."
}
},
"required": [
"modified",
"message"
]
}
! Above is the schema for output content !
When you are modifying the code, you should remember:
- Do not use print, this not work in dify sandbox.
- Do not try dangerous call like deleting files. It's PROHIBITED.
- Do not use any library that is not built-in in with Python.
- Get inputs from the parameters of the function and have explicit type annotations.
- Write proper imports at the top of the code.
- Use return statement to return the result.
- You should return a `dict`. If you need to return a `result: str`, you should `return {"result": result}`.
Your output must strictly follow the schema format, do not output any content outside of the JSON body.
""" # noqa: E501
INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}.
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""

View File

@ -136,4 +136,3 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
DATASOURCE_TRACE = "datasource"

View File

@ -1,21 +0,0 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
class OAuthSchema(BaseModel):
"""
OAuth schema
"""
client_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, model_validator
from werkzeug.exceptions import NotFound
from core.agent.plugin_entities import AgentStrategyProviderEntity
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
@ -63,7 +62,6 @@ class PluginCategory(enum.StrEnum):
Model = "model"
Extension = "extension"
AgentStrategy = "agent-strategy"
Datasource = "datasource"
class PluginDeclaration(BaseModel):
@ -71,7 +69,6 @@ class PluginDeclaration(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
datasources: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -95,7 +92,6 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
datasource: Optional[DatasourceProviderEntity] = None
meta: Meta
@model_validator(mode="before")
@ -106,8 +102,6 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Tool
elif values.get("model"):
values["category"] = PluginCategory.Model
elif values.get("datasource"):
values["category"] = PluginCategory.Datasource
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
else:
@ -190,11 +184,6 @@ class ToolProviderID(GenericProviderID):
self.plugin_name = f"{self.provider_name}_tool"
class DatasourceProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value

View File

@ -6,7 +6,6 @@ from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
@ -49,14 +48,6 @@ class PluginToolProviderEntity(BaseModel):
declaration: ToolProviderEntityWithPlugin
class PluginDatasourceProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
is_authorized: bool = False
declaration: DatasourceProviderEntityWithPlugin
class PluginAgentProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str

View File

@ -1,364 +0,0 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
WebsiteCrawlMessage,
)
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from services.tools.tools_transform_service import ToolTransformService
class PluginDatasourceManager(BasePluginClient):
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
all_response = [local_file_datasource_provider] + response
for provider in all_response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.datasources:
tool.identity.provider = provider.declaration.identity.name
return all_response
def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for datasource in declaration.get("datasources", []):
datasource["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginDatasourceProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.datasources:
tool.identity.provider = provider.declaration.identity.name
return response
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
"""
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):
datasource["identity"]["provider"] = tool_provider_id.provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasource",
PluginDatasourceProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for datasource in response.declaration.datasources:
datasource.identity.provider = response.declaration.identity.name
return response
def get_website_crawl(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
WebsiteCrawlMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def get_online_document_pages(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
OnlineDocumentPagesMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def get_online_document_page_content(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
return self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
DatasourceMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"page": datasource_parameters.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
def online_drive_browse_files(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
OnlineDriveBrowseFilesResponse,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"request": request.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
yield from response
def online_drive_download_file(
self,
tenant_id: str,
user_id: str,
datasource_provider: str,
datasource_name: str,
credentials: dict[str, Any],
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
"""
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
"""
datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
DatasourceMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"request": request.model_dump(),
},
},
headers={
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
yield from response
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
"""
# datasource_provider_id = GenericProviderID(provider_id)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
return False
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return {
"id": "langgenius/file/file",
"plugin_id": "langgenius/file",
"provider": "file",
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
"declaration": {
"identity": {
"author": "langgenius",
"name": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [
{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
}
],
},
}

View File

@ -4,10 +4,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginToolProviderEntity,
)
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
@ -202,36 +199,6 @@ class PluginToolManager(BasePluginClient):
return False
def validate_datasource_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the datasource
"""
tool_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
"provider": tool_provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": tool_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
return False
def get_runtime_parameters(
self,
tenant_id: str,

View File

@ -858,7 +858,7 @@ class ProviderManager:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -28,12 +28,10 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -51,17 +49,18 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -239,11 +238,7 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)

View File

@ -4,8 +4,8 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector):
)
try:
if self._hybrid_search_enabled:
self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""")
self._client.create_fts_idx_with_fts_index_param(
table_name=self._collection_name,
fts_idx_param=FtsIndexParam(
index_name="fulltext_index_for_col_text",
field_names=["text"],
parser_type=FtsParser.IK,
),
)
except Exception as e:
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above "
+ "to support fulltext index and vector index in the same table",
e,
)
self._client.refresh_metadata([self._collection_name])
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector):
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
distance_func=l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,

View File

@ -331,6 +331,12 @@ class QdrantVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score_threshold >= 1:
# return empty list because some versions of qdrant may response with 400 bad request
# and at the same time, the score_threshold with value 1 may be valid for other vector stores
return []
filter = models.Filter(
must=[
models.FieldCondition(
@ -355,7 +361,7 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=float(kwargs.get("score_threshold") or 0.0),
score_threshold=score_threshold,
)
docs = []
for result in results:
@ -363,7 +369,6 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

View File

@ -1,38 +0,0 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
class DatasourceStreamEvent(Enum):
"""
Datasource Stream event
"""
PROCESSING = "datasource_processing"
COMPLETED = "datasource_completed"
ERROR = "datasource_error"
class BaseDatasourceEvent(BaseModel):
pass
class DatasourceErrorEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.ERROR.value
error: str = Field(..., description="error message")
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
data: Mapping[str, Any] | list = Field(..., description="result")
total: Optional[int] = Field(default=0, description="total")
completed: Optional[int] = Field(default=0, description="completed")
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
total: Optional[int] = Field(..., description="total")
completed: Optional[int] = Field(..., description="completed")

View File

@ -11,7 +11,6 @@ class NotionInfo(BaseModel):
Notion import info.
"""
credential_id: Optional[str] = None
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str

View File

@ -171,7 +171,6 @@ class ExtractProcessor:
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,
tenant_id=extract_setting.notion_info.tenant_id,
credential_id=extract_setting.notion_info.credential_id,
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:

View File

@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor):
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = "crawl",
only_main_content: bool = True,
):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id

View File

@ -8,14 +8,7 @@ class JinaReaderWebExtractor(BaseExtractor):
Crawl and scrape websites and return content in clean llm-ready markdown.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = "crawl",
only_main_content: bool = False,
):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id

View File

@ -10,7 +10,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from services.datasource_provider_service import DatasourceProviderService
from models.source import DataSourceOauthBinding
logger = logging.getLogger(__name__)
@ -37,18 +37,16 @@ class NotionExtractor(BaseExtractor):
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
credential_id: Optional[str] = None,
):
self._notion_access_token = None
self._document_model = document_model
self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
self._credential_id = credential_id
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
if not self._notion_access_token:
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
@ -368,18 +366,23 @@ class NotionExtractor(BaseExtractor):
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
# get credential from tenant_id and credential_id
if not credential_id:
raise Exception(f"No credential id found for tenant {tenant_id}")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
)
if not credential:
raise Exception(f"No notion credential found for tenant {tenant_id} and credential {credential_id}")
return cast(str, credential["integration_secret"])
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)

View File

@ -16,14 +16,7 @@ class WaterCrawlWebExtractor(BaseExtractor):
only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = "crawl",
only_main_content: bool = True,
):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id

View File

@ -13,5 +13,3 @@ class MetadataDataSource(Enum):
upload_file = "file_upload"
website_crawl = "website"
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"

View File

@ -1,8 +1,7 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, Optional
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
@ -14,7 +13,6 @@ from core.rag.splitter.fixed_text_splitter import (
)
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
class BaseIndexProcessor(ABC):
@ -35,14 +33,6 @@ class BaseIndexProcessor(ABC):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError
@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError
@abstractmethod
def retrieve(
self,

View File

@ -1,22 +1,19 @@
"""Paragraph index processor."""
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Optional
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, GeneralStructureChunk
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -130,34 +127,3 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
paragraph = GeneralStructureChunk(**chunks)
documents = []
for content in paragraph.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

View File

@ -1,24 +1,20 @@
"""Paragraph index processor."""
import json
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -206,60 +202,3 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
parent_childs = ParentChildStructureChunk(**chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(parent_child.parent_content),
}
child_documents = []
for child in parent_child.child_contents:
child_metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(child),
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
documents.append(doc)
if documents:
# update document parent mode
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="hierarchical",
rules=json.dumps(
{
"parent_mode": parent_childs.parent_mode,
}
),
created_by=document.created_by,
)
db.session.add(dataset_process_rule)
db.session.flush()
document.dataset_process_rule_id = dataset_process_rule.id
db.session.commit()
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

View File

@ -4,8 +4,7 @@ import logging
import re
import threading
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Optional
import pandas as pd
from flask import Flask, current_app
@ -15,15 +14,13 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -164,36 +161,6 @@ class QAIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
qa_chunks = QAStructureChunk(**chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(qa_chunk.question),
"answer": qa_chunk.answer,
}
doc = Document(page_content=qa_chunk.question, metadata=metadata)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
else:
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
qa_chunks = QAStructureChunk(**chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {"qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks)}
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():

Some files were not shown because too many files have changed in this diff Show More