Compare commits

..

16 Commits

Author SHA1 Message Date
594906c1ff fix: MD5 and 8‑hex Suffix Collision Risk 2025-09-24 17:01:23 +08:00
80f8245f2e fix(api): sync api/uv.lock with main to resolve binary diff 2025-09-24 12:00:50 +08:00
a12b437c16 fix(api): sync api/uv.lock with main to resolve binary diff 2025-09-24 11:58:07 +08:00
12de554313 fix: add index initialization checks, improve batch vector operations and search, ensure robust exception handling. 2025-09-23 16:41:46 +08:00
1f36c0c1c5 sync docker compose files with main branch 2025-09-23 00:12:54 +08:00
8b9297563c fix 2025-09-23 00:03:31 +08:00
1cbe9eedb6 fix(pinecone): normalize index names and sanitize metadata to meet API constraints 2025-09-20 02:56:53 +08:00
90fc5a1f12 pipecone 2025-09-16 08:57:46 +08:00
41dfdf1ac0 fix:score threshold 2025-09-01 16:34:17 +08:00
dd7de74aa6 修复top-k硬编码回退问题 2025-09-01 14:27:43 +08:00
f11131f8b5 fix: basepath did not read from the environment variable (#24870) 2025-09-01 13:50:33 +08:00
2e6e414a9e the conversion OAuthGrantType(parsed_args["grant_type"]) can raise ValueError for invalid values which is not caught and will produce a 500 (#24854) 2025-09-01 10:05:54 +08:00
c45d676477 remove duplicated authorization header handling and bearer should be case-insensitive (#24852) 2025-09-01 10:05:19 +08:00
b8d8dddd5a example of decorator typing (#24857)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-01 10:04:24 +08:00
c45c22b1b2 fix translation of all oauth.ts (#24855) 2025-09-01 10:04:05 +08:00
3d57a9ccdc Fix never hit (!code || code.length === 0) (#24860) 2025-09-01 09:45:07 +08:00
420 changed files with 14023 additions and 22397 deletions

View File

@ -156,7 +156,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `pinecone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -361,6 +361,17 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Pinecone configuration, only available when VECTOR_STORE is `pinecone`
PINECONE_API_KEY=your-pinecone-api-key
PINECONE_ENVIRONMENT=your-pinecone-environment
PINECONE_INDEX_NAME=dify-index
PINECONE_CLIENT_TIMEOUT=30
PINECONE_BATCH_SIZE=100
PINECONE_METRIC=cosine
PINECONE_PODS=1
PINECONE_POD_TYPE=s1
# Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
@ -460,16 +471,6 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
# Seconds of idle time before scaling down workers (default: 5.0)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)

View File

@ -1,112 +0,0 @@
[importlinter]
root_packages =
core
configs
controllers
models
tasks
services
[importlinter:contract:workflow]
name = Workflow
type=layers
layers =
graph_engine
graph_events
graph
nodes
node_events
entities
containers =
core.workflow
ignore_imports =
core.workflow.nodes.base.node -> core.workflow.graph_events
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
core.workflow.nodes.node_factory -> core.workflow.graph
[importlinter:contract:rsc]
name = RSC
type = layers
layers =
graph_engine
response_coordinator
containers =
core.workflow.graph_engine
[importlinter:contract:worker]
name = Worker
type = layers
layers =
graph_engine
worker
containers =
core.workflow.graph_engine
[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
type = layers
layers =
graph_engine
orchestration
command_processing
event_management
error_handling
graph_traversal
state_management
worker_management
domain
containers =
core.workflow.graph_engine
[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
core.workflow.graph_engine.domain
forbidden_modules =
core.workflow.graph_engine.worker_management
core.workflow.graph_engine.command_channels
core.workflow.graph_engine.layers
core.workflow.graph_engine.protocols
[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
core.workflow.graph_engine.worker_management
forbidden_modules =
core.workflow.graph_engine.orchestration
core.workflow.graph_engine.command_processing
core.workflow.graph_engine.event_management
[importlinter:contract:error-handling-strategies]
name = Error Handling Strategies
type = independence
modules =
core.workflow.graph_engine.error_handling.abort_strategy
core.workflow.graph_engine.error_handling.retry_strategy
core.workflow.graph_engine.error_handling.fail_branch_strategy
core.workflow.graph_engine.error_handling.default_value_strategy
[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = layers
layers =
edge_processor
skip_propagator
containers =
core.workflow.graph_engine.graph_traversal
[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
core.workflow.graph_engine.command_channels.in_memory_channel
core.workflow.graph_engine.command_channels.redis_channel

View File

@ -13,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
@ -30,7 +31,6 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.provider_ids import ToolProviderID
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs

View File

@ -529,28 +529,6 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024,
)
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
default=1,
)
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
description="Maximum number of workers per GraphEngine instance",
default=10,
)
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
description="Queue depth threshold that triggers worker scale up",
default=3,
)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
description="Seconds of idle time before scaling down workers",
default=5.0,
ge=0.1,
)
class WorkflowNodeExecutionConfig(BaseSettings):
"""

View File

@ -35,6 +35,7 @@ from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.pinecone_config import PineconeConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tablestore_config import TableStoreConfig
@ -331,6 +332,7 @@ class MiddlewareConfig(
PGVectorConfig,
VastbaseVectorConfig,
PGVectoRSConfig,
PineconeConfig,
QdrantConfig,
RelytConfig,
TencentVectorDBConfig,

View File

@ -0,0 +1,41 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class PineconeConfig(BaseSettings):
"""
Configuration settings for Pinecone vector database
"""
PINECONE_API_KEY: Optional[str] = Field(
description="API key for authenticating with Pinecone service",
default=None,
)
PINECONE_ENVIRONMENT: Optional[str] = Field(
description="Pinecone environment (e.g., 'us-west1-gcp', 'us-east-1-aws')",
default=None,
)
PINECONE_INDEX_NAME: Optional[str] = Field(
description="Default Pinecone index name",
default=None,
)
PINECONE_CLIENT_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for Pinecone client operations (default is 30 seconds)",
default=30,
)
PINECONE_BATCH_SIZE: PositiveInt = Field(
description="Batch size for Pinecone operations (default is 100)",
default=100,
)
PINECONE_METRIC: str = Field(
description="Distance metric for Pinecone index (cosine, euclidean, dotproduct)",
default="cosine",
)

View File

@ -16,10 +16,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from models import App
from services.workflow_service import WorkflowService
class RuleGenerateApi(Resource):
@ -138,6 +135,9 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
from models import App, db
from services.workflow_service import WorkflowService
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400

View File

@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@ -414,12 +413,7 @@ class WorkflowTaskStopApi(Resource):
if not current_user.is_editor:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required

View File

@ -17,11 +17,10 @@ from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models import App, AppMode, db
from models.account import Account
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

View File

@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
if not request.headers.get("Authorization"):
raise BadRequest("Authorization is required")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.split(" ")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0]
if token_type != "Bearer":
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1]
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
grant_type = OAuthGrantType(parsed_args["grant_type"])
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token,
}
)
else:
raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource):

View File

@ -19,6 +19,7 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -30,7 +31,6 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -660,6 +660,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.PINECONE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -711,6 +712,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.PINECONE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (

View File

@ -20,7 +20,6 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from models.model import AppMode, InstalledApp
@ -79,11 +78,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}

View File

@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

View File

@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
from models import db as global_db
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@ -26,8 +26,7 @@ from core.errors.error import (
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
@ -263,12 +262,7 @@ class WorkflowTaskStopApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {"result": "success"}

View File

@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService

View File

@ -21,7 +21,6 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@ -111,12 +110,7 @@ class WorkflowTaskStopApi(WebApiResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {"result": "success"}

View File

@ -4,8 +4,8 @@ from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from models.provider_ids import ModelProviderID
class ModelConfigManager:

View File

@ -1,11 +1,11 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -23,17 +23,16 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
@ -77,32 +76,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
graph_runtime_state.variable_pool = variable_pool
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
graph_runtime_state.variable_pool = variable_pool
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
@ -153,27 +144,16 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
db.session.close()
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -184,11 +164,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
variable_pool=variable_pool,
)
generator = workflow_entry.run()
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -31,9 +31,14 @@ from core.app.entities.queue_entities import (
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -60,8 +65,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
@ -390,7 +395,9 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -435,6 +442,32 @@ class AdvancedChatAppGenerateTaskPipeline:
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -726,6 +759,8 @@ class AdvancedChatAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -773,6 +808,8 @@ class AdvancedChatAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -785,6 +822,17 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
@ -808,6 +856,11 @@ class AdvancedChatAppGenerateTaskPipeline:
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.enums import NodeType
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,

View File

@ -126,21 +126,6 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@classmethod
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
"""
Set task stop flag without user permission check.
This method allows stopping workflows without user context.
:param task_id: The task ID to stop
:return:
"""
if not task_id:
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
def _is_stopped(self) -> bool:
"""
Check if task is stopped

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from sqlalchemy.orm import Session
@ -16,9 +16,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
@ -31,16 +36,18 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
@ -164,10 +171,11 @@ class WorkflowResponseConverter:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
return response
@ -175,7 +183,11 @@ class WorkflowResponseConverter:
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@ -209,6 +221,9 @@ class WorkflowResponseConverter:
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -260,6 +275,50 @@ class WorkflowResponseConverter:
),
)
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
@ -274,11 +333,13 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -296,10 +357,15 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -318,7 +384,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
@ -332,6 +398,8 @@ class WorkflowResponseConverter:
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -345,7 +413,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -369,7 +437,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
index=event.index,
pre_loop_output=event.output,
created_at=int(time.time()),
@ -377,6 +445,7 @@ class WorkflowResponseConverter:
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -394,7 +463,7 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
title=event.node_data.title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -53,6 +53,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None]: ...
@overload
@ -66,6 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...
@overload
@ -79,6 +81,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
@ -91,6 +94,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -182,6 +186,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
def _generate(
@ -195,6 +200,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@ -208,6 +214,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -230,6 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
@ -426,7 +434,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
@ -456,6 +474,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,

View File

@ -1,7 +1,7 @@
import logging
import time
from typing import cast
from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -9,14 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from models.enums import UserFrom
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
@ -32,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None:
@ -41,6 +41,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
@ -51,30 +52,24 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
)
else:
inputs = self.application_generate_entity.inputs
@ -97,26 +92,15 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -127,11 +111,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)
generator = workflow_entry.run()
generator = workflow_entry.run(callbacks=workflow_callbacks)
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Optional, Union
from typing import Any, Optional, Union
from sqlalchemy.orm import Session
@ -14,7 +14,6 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -26,9 +25,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -53,8 +57,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -346,7 +350,9 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -365,6 +371,32 @@ class WorkflowAppGenerateTaskPipeline:
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -586,6 +618,8 @@ class WorkflowAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -600,7 +634,7 @@ class WorkflowAppGenerateTaskPipeline:
def _dispatch_event(
self,
event: AppQueueEvent,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
@ -627,6 +661,8 @@ class WorkflowAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -639,6 +675,17 @@ class WorkflowAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(

View File

@ -2,7 +2,6 @@ from collections.abc import Mapping
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -14,9 +13,14 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -24,39 +28,42 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import (
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeInIterationFailedEvent,
NodeInLoopFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
@ -72,14 +79,7 @@ class WorkflowBasedAppRunner:
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
) -> Graph:
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
@ -91,28 +91,8 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
@ -124,7 +104,6 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
node_id: str,
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
@ -149,9 +128,7 @@ class WorkflowBasedAppRunner:
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get("iteration_id", "") == node_id
or node.get("id") == f"{node_id}start"
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
]
graph_config["nodes"] = node_configs
@ -168,25 +145,8 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
@ -241,7 +201,6 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
node_id: str,
user_inputs: dict,
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
@ -266,9 +225,7 @@ class WorkflowBasedAppRunner:
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get("loop_id", "") == node_id
or node.get("id") == f"{node_id}start"
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
@ -285,25 +242,8 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
@ -370,21 +310,29 @@ class WorkflowBasedAppRunner:
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
@ -395,8 +343,6 @@ class WorkflowBasedAppRunner:
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
@ -404,30 +350,44 @@ class WorkflowBasedAppRunner:
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -436,18 +396,34 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -458,21 +434,93 @@ class WorkflowBasedAppRunner:
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeInLoopFailedEvent):
self._publish_event(
QueueNodeInLoopFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -485,10 +533,10 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.message_id,
id=event.id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
@ -499,13 +547,51 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
)
)
elif isinstance(event, NodeRunIterationStartedEvent):
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -513,41 +599,55 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunIterationNextEvent):
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunLoopStartedEvent):
elif isinstance(event, LoopRunStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -555,32 +655,42 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunLoopNextEvent):
elif isinstance(event, LoopRunNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
)
)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@ -11,7 +11,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(StrEnum):
class InvokeFrom(Enum):
"""
Invoke From.
"""

View File

@ -7,9 +7,11 @@ from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(StrEnum):
@ -41,6 +43,9 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
@ -75,7 +80,15 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
@ -95,9 +108,20 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
class QueueIterationCompletedEvent(AppQueueEvent):
@ -110,7 +134,15 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
@ -131,7 +163,7 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -159,7 +191,7 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -172,6 +204,7 @@ class QueueLoopNextEvent(AppQueueEvent):
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
class QueueLoopCompletedEvent(AppQueueEvent):
@ -184,7 +217,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -331,24 +364,27 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_title: str
node_type: NodeType
node_run_index: int = 1 # FIXME(-LAN-): may not used
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str
class QueueNodeSucceededEvent(AppQueueEvent):
"""
@ -360,6 +396,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -380,6 +417,10 @@ class QueueNodeSucceededEvent(AppQueueEvent):
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: Optional[str] = None
"""single iteration duration map"""
iteration_duration_map: Optional[dict[str, float]] = None
"""single loop duration map"""
loop_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
@ -413,6 +454,72 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
retry_index: int # retry index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeInLoopFailedEvent(AppQueueEvent):
"""
QueueNodeInLoopFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
@ -423,6 +530,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
@ -455,7 +563,15 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
@ -562,3 +678,61 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
error: str

View File

@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@ -71,6 +71,8 @@ class StreamEvent(Enum):
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -438,6 +440,54 @@ class NodeRetryStreamResponse(StreamResponse):
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -456,6 +506,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -478,7 +530,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -510,6 +567,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -563,6 +622,7 @@ class LoopNodeNextStreamResponse(StreamResponse):
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str

View File

@ -28,6 +28,7 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
@ -40,7 +41,6 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@ -627,7 +627,6 @@ class ProviderConfiguration(BaseModel):
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1125,7 +1124,6 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1209,7 +1207,6 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():

View File

@ -12,8 +12,8 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@ -28,9 +28,8 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.node_events import AgentLogEvent
from extensions.ext_database import db
from models import App, Message, WorkflowNodeExecutionModel
from core.workflow.graph_engine.entities.event import AgentLogEvent
from models import App, Message, WorkflowNodeExecutionModel, db
logger = logging.getLogger(__name__)

View File

@ -167,11 +167,11 @@ class TokenBufferMemory:
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
if not prompt_messages:
return []
if not prompt_messages:
return []
# prune the chat message if it exceeds the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
# prune the chat message if it exceeds the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
if curr_message_tokens > max_token_limit:
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:

View File

@ -24,7 +24,8 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.impl.model import PluginModelClient
class AIModel(BaseModel):
@ -52,8 +53,6 @@ class AIModel(BaseModel):
:return: Invoke error mapping
"""
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
@ -141,8 +140,6 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import (
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -141,8 +142,6 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
@ -341,8 +340,6 @@ class LargeLanguageModel(AIModel):
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,

View File

@ -5,6 +5,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class ModerationModel(AIModel):
@ -30,8 +31,6 @@ class ModerationModel(AIModel):
self.started_at = time.perf_counter()
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,

View File

@ -3,6 +3,7 @@ from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class RerankModel(AIModel):
@ -35,8 +36,6 @@ class RerankModel(AIModel):
:return: rerank result
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,

View File

@ -4,6 +4,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class Speech2TextModel(AIModel):
@ -27,8 +28,6 @@ class Speech2TextModel(AIModel):
:return: text for given audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,

View File

@ -6,6 +6,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
class TextEmbeddingModel(AIModel):
@ -36,8 +37,6 @@ class TextEmbeddingModel(AIModel):
:param input_type: input type
:return: embeddings result
"""
from core.plugin.impl.model import PluginModelClient
try:
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
@ -62,8 +61,6 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,

View File

@ -6,6 +6,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -41,8 +42,6 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
@ -66,8 +65,6 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
from core.plugin.impl.model import PluginModelClient
plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,

View File

@ -20,8 +20,10 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin import ModelProviderID
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from models.provider_ids import ModelProviderID
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -35,8 +37,6 @@ class ModelProviderFactory:
provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None:
from core.plugin.impl.model import PluginModelClient
self.provider_position_map = {}
self.tenant_id = tenant_id
@ -71,7 +71,7 @@ class ModelProviderFactory:
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -109,7 +109,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
@ -366,8 +366,6 @@ class ModelProviderFactory:
mime_type = image_mime_types.get(extension, "image/png")
# get icon bytes from plugin asset manager
from core.plugin.impl.asset import PluginAssetManager
plugin_asset_manager = PluginAssetManager()
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@ -377,6 +375,5 @@ class ModelProviderFactory:
:param provider: provider name
:return: plugin id and provider name
"""
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name

View File

@ -54,10 +54,13 @@ from core.ops.entities.trace_entity import (
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes import NodeType
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
logger = logging.getLogger(__name__)

View File

@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus

View File

@ -28,7 +28,8 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -22,7 +22,8 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -5,7 +5,7 @@ import queue
import threading
import time
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import Any, Optional, Union
from uuid import UUID, uuid4
from cachetools import LRUCache
@ -30,15 +30,13 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
from core.workflow.entities import WorkflowExecution
logger = logging.getLogger(__name__)
@ -412,7 +410,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: Optional[str] = None,
workflow_execution: Optional["WorkflowExecution"] = None,
workflow_execution: Optional[WorkflowExecution] = None,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,

View File

@ -23,7 +23,8 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

View File

@ -164,6 +164,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
call_depth=1,
workflow_thread_pool_id=None,
)
@classmethod

View File

@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.enums import NodeType
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)

View File

@ -1,11 +1,11 @@
import enum
import json
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.tools.entities.common_entities import I18nObject
from core.workflow.nodes.base.entities import NumberType
class PluginParameterOption(BaseModel):
@ -154,7 +154,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError("The tools selector must be a list.")
return value
case PluginParameterType.ANY:
if value and not isinstance(value, str | dict | list | int | float):
if value and not isinstance(value, str | dict | list | NumberType):
raise ValueError("The var selector must be a string, dictionary, list or number.")
return value
case PluginParameterType.ARRAY:
@ -162,6 +162,8 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for arrays
if isinstance(value, str):
try:
import json
parsed_value = json.loads(value)
if isinstance(parsed_value, list):
return parsed_value
@ -174,6 +176,8 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
# Try to parse JSON string for objects
if isinstance(value, str):
try:
import json
parsed_value = json.loads(value)
if isinstance(parsed_value, dict):
return parsed_value

View File

@ -1,9 +1,11 @@
import datetime
import enum
import re
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
from werkzeug.exceptions import NotFound
from core.agent.plugin_entities import AgentStrategyProviderEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
@ -133,6 +135,55 @@ class PluginEntity(PluginInstallation):
return self
class GenericProviderID:
organization: str
plugin_name: str
provider_name: str
is_hardcoded: bool
def to_string(self) -> str:
return str(self)
def __str__(self) -> str:
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
if not value:
raise NotFound("plugin not found, please add plugin")
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
if re.match(r"^[a-z0-9_-]+$", value):
value = f"langgenius/{value}/{value}"
else:
raise ValueError(f"Invalid plugin id {value}")
self.organization, self.plugin_name, self.provider_name = value.split("/")
self.is_hardcoded = is_hardcoded
def is_langgenius(self) -> bool:
return self.organization == "langgenius"
@property
def plugin_id(self) -> str:
return f"{self.organization}/{self.plugin_name}"
class ModelProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)
if self.organization == "langgenius" and self.provider_name == "google":
self.plugin_name = "gemini"
class ToolProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
super().__init__(value, is_hardcoded)
if self.organization == "langgenius":
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
self.plugin_name = f"{self.provider_name}_tool"
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value

View File

@ -2,13 +2,13 @@ from collections.abc import Generator
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from models.provider_ids import GenericProviderID
class PluginAgentClient(BasePluginClient):

View File

@ -1,9 +1,9 @@
from collections.abc import Mapping
from typing import Any
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
from core.plugin.impl.base import BasePluginClient
from models.provider_ids import GenericProviderID
class DynamicSelectClient(BasePluginClient):

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
MissingPluginDependency,
PluginDeclaration,
PluginEntity,
@ -15,7 +16,6 @@ from core.plugin.entities.plugin_daemon import (
PluginListResponse,
)
from core.plugin.impl.base import BasePluginClient
from models.provider_ids import GenericProviderID
class PluginInstaller(BasePluginClient):

View File

@ -3,11 +3,11 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
from models.provider_ids import GenericProviderID, ToolProviderID
class PluginToolManager(BasePluginClient):

View File

@ -34,6 +34,7 @@ from core.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions import ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -48,7 +49,6 @@ from models.provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService

View File

@ -24,7 +24,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

View File

@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(

View File

@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,

View File

@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)

View File

@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index]
metadata = dict(metadatas[index])
score = 1 - distance
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],

View File

@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(

View File

@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs

View File

@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs

View File

@ -195,7 +195,7 @@ class PGVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -0,0 +1,341 @@
import json
import time
from typing import Any, Optional
from pinecone import Pinecone, ServerlessSpec
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
class PineconeConfig(BaseModel):
"""Pinecone configuration class"""
api_key: str
environment: str
index_name: Optional[str] = None
timeout: float = 30
batch_size: int = 100
metric: str = "cosine"
class PineconeVector(BaseVector):
"""Pinecone vector database concrete implementation class"""
def __init__(self, collection_name: str, group_id: str, config: PineconeConfig):
super().__init__(collection_name)
self._client_config = config
self._group_id = group_id
# Initialize Pinecone client with SSL configuration
try:
self._pc = Pinecone(
api_key=config.api_key,
# Configure SSL to handle connection issues
ssl_ca_certs=None, # Use system default CA certificates
)
except Exception as e:
# Fallback to basic initialization if SSL config fails
self._pc = Pinecone(api_key=config.api_key)
# Normalize index name: lowercase, only a-z0-9- and <=45 chars
import re, hashlib
base_name = collection_name.lower()
base_name = re.sub(r'[^a-z0-9-]+', '-', base_name) # replace invalid chars with '-'
base_name = re.sub(r'-+', '-', base_name).strip('-')
# Use longer secure suffix to reduce collision risk
suffix_len = 24 # 24 hex digits (96-bit entropy)
if len(base_name) > 45:
hash_suffix = hashlib.sha256(base_name.encode()).hexdigest()[:suffix_len]
truncated_name = base_name[:45-(suffix_len+1)].rstrip('-')
self._index_name = f"{truncated_name}-{hash_suffix}"
else:
self._index_name = base_name
# Guard empty name
if not self._index_name:
self._index_name = f"index-{hashlib.sha256(collection_name.encode()).hexdigest()[:suffix_len]}"
self._index = None
def get_type(self) -> str:
"""Return vector database type identifier"""
return "pinecone"
def _ensure_index_initialized(self) -> None:
"""Ensure that self._index is attached to an existing Pinecone index."""
if self._index is not None:
return
try:
existing_indexes = self._pc.list_indexes().names()
if self._index_name in existing_indexes:
self._index = self._pc.Index(self._index_name)
else:
raise ValueError("Index not initialized. Please ingest documents to create index.")
except Exception:
raise
def to_index_struct(self) -> dict:
"""Generate index structure dictionary"""
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create vector index"""
if texts:
# Get vector dimension
vector_size = len(embeddings[0])
# Create Pinecone index
self.create_index(vector_size)
# Add vector data
self.add_texts(texts, embeddings, **kwargs)
def create_index(self, dimension: int):
"""Create Pinecone index"""
lock_name = f"vector_indexing_lock_{self._index_name}"
with redis_client.lock(lock_name, timeout=30):
# Check Redis cache
index_exist_cache_key = f"vector_indexing_{self._index_name}"
if redis_client.get(index_exist_cache_key):
self._index = self._pc.Index(self._index_name)
return
# Check if index already exists
existing_indexes = self._pc.list_indexes().names()
if self._index_name not in existing_indexes:
# Create new index using ServerlessSpec
self._pc.create_index(
name=self._index_name,
dimension=dimension,
metric=self._client_config.metric,
spec=ServerlessSpec(
cloud='aws',
region=self._client_config.environment
)
)
# Wait for index creation to complete
while not self._pc.describe_index(self._index_name).status['ready']:
time.sleep(1)
else:
# Get index instance
self._index = self._pc.Index(self._index_name)
# Set cache
redis_client.set(index_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Batch add document vectors"""
if not self._index:
raise ValueError("Index not initialized. Call create() first.")
total_docs = len(documents)
uuids = self._get_uuids(documents)
batch_size = self._client_config.batch_size
added_ids = []
# Batch processing
total_batches = (total_docs + batch_size - 1) // batch_size # Ceiling division
for batch_idx, i in enumerate(range(0, len(documents), batch_size), 1):
batch_documents = documents[i:i + batch_size]
batch_embeddings = embeddings[i:i + batch_size]
batch_uuids = uuids[i:i + batch_size]
batch_size_actual = len(batch_documents)
# Build Pinecone vector data (metadata must be primitives or list[str])
vectors_to_upsert = []
for doc, embedding, doc_id in zip(batch_documents, batch_embeddings, batch_uuids):
raw_meta = doc.metadata or {}
safe_meta: dict[str, Any] = {}
# lift common identifiers to top-level fields for filtering
for k, v in raw_meta.items():
if isinstance(v, (str, int, float, bool)):
safe_meta[k] = v
elif isinstance(v, list) and all(isinstance(x, str) for x in v):
safe_meta[k] = v
else:
safe_meta[k] = json.dumps(v, ensure_ascii=False)
# keep content as string metadata if needed
safe_meta[Field.CONTENT_KEY.value] = doc.page_content
# group id as string
safe_meta[Field.GROUP_KEY.value] = str(self._group_id)
vectors_to_upsert.append({
"id": doc_id,
"values": embedding,
"metadata": safe_meta
})
# Batch insert to Pinecone
try:
self._index.upsert(vectors=vectors_to_upsert)
added_ids.extend(batch_uuids)
except Exception as e:
raise
return added_ids
def search_by_vector(self, query_vector: list[float], **kwargs) -> list[Document]:
"""Vector similarity search"""
# Lazily attach to an existing index if needed
self._ensure_index_initialized()
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold", 0.0))
# Build filter conditions
filter_dict = {Field.GROUP_KEY.value: {"$eq": str(self._group_id)}}
# Document scope filtering
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_dict["document_id"] = {"$in": document_ids_filter}
# Execute search
try:
response = self._index.query(
vector=query_vector,
top_k=top_k,
include_metadata=True,
filter=filter_dict
)
except Exception as e:
raise
# Convert results
docs = []
filtered_count = 0
for match in response.matches:
if match.score >= score_threshold:
page_content = match.metadata.get(Field.CONTENT_KEY.value, "")
metadata = dict(match.metadata or {})
metadata.pop(Field.CONTENT_KEY.value, None)
metadata.pop(Field.GROUP_KEY.value, None)
metadata["score"] = match.score
doc = Document(page_content=page_content, metadata=metadata)
docs.append(doc)
else:
filtered_count += 1
# Sort by similarity score in descending order
docs.sort(key=lambda x: x.metadata.get("score", 0), reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs) -> list[Document]:
"""Full-text search - Pinecone does not natively support it, returns empty list"""
return []
def delete_by_metadata_field(self, key: str, value: str):
"""Delete by metadata field"""
self._ensure_index_initialized()
try:
# Build filter conditions
filter_dict = {
Field.GROUP_KEY.value: {"$eq": self._group_id},
f"{Field.METADATA_KEY.value}.{key}": {"$eq": value}
}
# Pinecone delete operation
self._index.delete(filter=filter_dict)
except Exception as e:
# Ignore delete errors
pass
def delete_by_ids(self, ids: list[str]) -> None:
"""Batch delete by ID list"""
self._ensure_index_initialized()
try:
# Pinecone delete by ID
self._index.delete(ids=ids)
except Exception as e:
raise
def delete(self) -> None:
"""Delete all vector data for the entire dataset"""
self._ensure_index_initialized()
try:
# Delete all vectors by group_id
filter_dict = {Field.GROUP_KEY.value: {"$eq": self._group_id}}
self._index.delete(filter=filter_dict)
except Exception as e:
raise
def text_exists(self, id: str) -> bool:
"""Check if document exists"""
try:
self._ensure_index_initialized()
except Exception:
return False
try:
# Check if vector exists through query
response = self._index.fetch(ids=[id])
exists = id in response.vectors
return exists
except Exception as e:
return False
class PineconeVectorFactory(AbstractVectorFactory):
"""Pinecone vector database factory class"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PineconeVector:
"""Create PineconeVector instance"""
# Determine index name
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError("Dataset Collection Bindings does not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
# Set index structure
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(
self.gen_index_struct_dict("pinecone", collection_name)
)
# Create PineconeVector instance
return PineconeVector(
collection_name=collection_name,
group_id=dataset.id,
config=PineconeConfig(
api_key=dify_config.PINECONE_API_KEY or "",
environment=dify_config.PINECONE_ENVIRONMENT or "",
index_name=dify_config.PINECONE_INDEX_NAME,
timeout=dify_config.PINECONE_CLIENT_TIMEOUT,
batch_size=dify_config.PINECONE_BATCH_SIZE,
metric=dify_config.PINECONE_METRIC,
),
)

View File

@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -369,7 +369,7 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

View File

@ -233,7 +233,7 @@ class RelytVector(BaseVector):
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
if 1 - score >= score_threshold:
docs.append(document)
return docs

View File

@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
)
documents = []
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

View File

@ -291,7 +291,7 @@ class TencentVector(BaseVector):
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)

View File

@ -351,7 +351,7 @@ class TidbOnQdrantVector(BaseVector):
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

View File

@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -86,6 +86,10 @@ class Vector:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.PINECONE:
from core.rag.datasource.vdb.pinecone.pinecone_vector import PineconeVectorFactory
return PineconeVectorFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory

View File

@ -31,3 +31,4 @@ class VectorType(StrEnum):
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
PINECONE = "pinecone"

View File

@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -10,6 +10,23 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
def _format_cell_value(value) -> str:
if pd.isna(value):
return ""
if isinstance(value, (int, float)):
if isinstance(value, float):
if value.is_integer():
return str(int(value))
else:
formatted = f"{value:f}"
return formatted.rstrip('0').rstrip('.')
else:
return str(value)
return str(value)
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
@ -49,10 +66,12 @@ class ExcelExtractor(BaseExtractor):
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
formatted_v = _format_cell_value(v)
value = f"[{formatted_v}]({cell.hyperlink.target})"
page_content.append(f'"{k}":"{value}"')
else:
page_content.append(f'"{k}":"{v}"')
formatted_v = _format_cell_value(v)
page_content.append(f'"{k}":"{formatted_v}"')
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
@ -67,7 +86,8 @@ class ExcelExtractor(BaseExtractor):
page_content = []
for k, v in row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
formatted_v = _format_cell_value(v)
page_content.append(f'"{k}":"{formatted_v}"')
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)

View File

@ -1,9 +1,10 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
@ -13,9 +14,6 @@ from core.rag.splitter.fixed_text_splitter import (
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
if TYPE_CHECKING:
from core.model_manager import ModelInstance
class BaseIndexProcessor(ABC):
"""Interface for extract files."""
@ -53,7 +51,7 @@ class BaseIndexProcessor(ABC):
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional["ModelInstance"],
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.

View File

@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -162,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -647,7 +647,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -743,7 +743,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,

View File

@ -9,8 +9,11 @@ from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution import (
WorkflowExecution,
WorkflowExecutionStatus,
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
@ -200,4 +203,5 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
session.commit()
# Update the in-memory cache for faster subsequent lookups
logger.debug("Updating cache for execution_id: %s", db_model.id)
self._execution_cache[db_model.id] = db_model

View File

@ -12,8 +12,12 @@ from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.helper import extract_tenant_id
@ -211,6 +215,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Update the in-memory cache for faster subsequent lookups
# Only cache if we have a node_execution_id to use as the cache key
if db_model.node_execution_id:
logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_db_models_by_workflow_run(

View File

@ -152,6 +152,7 @@ class ToolEngine:
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
@ -165,6 +166,7 @@ class ToolEngine:
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

View File

@ -13,16 +13,31 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -38,28 +53,16 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.nodes.tool.entities import ToolEntity
logger = logging.getLogger(__name__)
@ -114,8 +117,6 @@ class ToolManager:
get the plugin provider
"""
# check if context is set
from core.plugin.impl.tool import PluginToolManager
try:
contexts.plugin_tool_providers.get()
except LookupError:
@ -171,7 +172,6 @@ class ToolManager:
:return: the tool
"""
if provider_type == ToolProviderType.BUILT_IN:
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
@ -216,16 +216,16 @@ class ToolManager:
# fallback to the default provider
if builtin_provider is None:
# use the default provider
with Session(db.engine) as session:
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
@ -256,7 +256,6 @@ class ToolManager:
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
@ -264,7 +263,6 @@ class ToolManager:
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
@ -360,7 +358,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: Optional[VariablePool] = None,
) -> Tool:
"""
get the agent tool runtime
@ -402,7 +400,7 @@ class ToolManager:
node_id: str,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: Optional[VariablePool] = None,
) -> Tool:
"""
get the workflow tool runtime
@ -518,8 +516,6 @@ class ToolManager:
"""
list all the plugin providers
"""
from core.plugin.impl.tool import PluginToolManager
manager = PluginToolManager()
provider_entities = manager.fetch_tool_providers(tenant_id)
return [
@ -981,7 +977,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"],
variable_pool: Optional[VariablePool],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@ -181,7 +181,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
top_k: int = 4
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool

View File

@ -37,12 +37,14 @@ class WorkflowTool(Tool):
entity: ToolEntity,
runtime: ToolRuntime,
label: str = "Workflow",
thread_pool_id: Optional[str] = None,
):
self.workflow_app_id = workflow_app_id
self.workflow_as_tool_id = workflow_as_tool_id
self.version = version
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.thread_pool_id = thread_pool_id
self.label = label
super().__init__(entity=entity, runtime=runtime)
@ -86,6 +88,7 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id,
)
assert isinstance(result, dict)
data = result.get("data", {})

View File

@ -130,7 +130,7 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
items.append(f"- {item}")
items.append(str(item))
return "\n".join(items)

Some files were not shown because too many files have changed in this diff Show More