Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -1 +0,0 @@
import core.moderation.base

View File

@ -1,7 +1,7 @@
import json
import logging
import uuid
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
@ -60,9 +60,9 @@ class BaseAgentRunner(AppRunner):
message: Message,
user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
) -> None:
memory: TokenBufferMemory | None = None,
prompt_messages: list[PromptMessage] | None = None,
):
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
tenant_id=tenant_id,
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
return_resource=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,
@ -112,7 +114,7 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.query: Optional[str] = ""
self.query: str | None = ""
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
@ -334,7 +336,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@ -492,7 +495,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
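
Both lookups above move from the legacy Query API to SQLAlchemy 2.0 style: build a select() statement, then hand it to the session. A minimal sketch of the two patterns, reusing imports that appear elsewhere in this commit (message_id is illustrative):

    from sqlalchemy import select
    from extensions.ext_database import db
    from models import MessageFile

    message_id = "illustrative-id"

    # one row or None: replaces db.session.query(...).where(...).first()
    stmt = select(MessageFile).where(MessageFile.message_id == message_id)
    first_file = db.session.scalar(stmt)

    # all rows: replaces db.session.query(...).where(...).all()
    files = db.session.scalars(stmt).all()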

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
@ -70,10 +70,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@ -120,7 +122,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
usage_dict: dict[str, Optional[LLMUsage]] = {}
usage_dict: dict[str, LLMUsage | None] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@ -272,7 +274,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
trace_manager: TraceQueueManager | None = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
@ -338,7 +340,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return instruction
def _init_react_state(self, query) -> None:
def _init_react_state(self, query):
"""
init agent scratchpad
"""

View File

@ -1,5 +1,4 @@
import json
from typing import Optional
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
@ -31,7 +30,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
return system_prompt
def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str:
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
"""
Organize historic prompt
"""

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Optional, Union
from typing import Any, Union
from pydantic import BaseModel, Field
@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
action_name: str
action_input: Union[dict, str]
def to_dict(self) -> dict:
def to_dict(self):
"""
Convert to dictionary.
"""
@ -50,11 +50,11 @@ class AgentScratchpadUnit(BaseModel):
"action_input": self.action_input,
}
agent_response: Optional[str] = None
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
agent_response: str | None = None
thought: str | None = None
action_str: str | None = None
observation: str | None = None
action: Action | None = None
def is_final(self) -> bool:
"""
@ -81,8 +81,8 @@ class AgentEntity(BaseModel):
provider: str
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: Optional[list[AgentToolEntity]] = None
prompt: AgentPromptEntity | None = None
tools: list[AgentToolEntity] | None = None
max_iteration: int = 10

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
@ -52,13 +52,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there are no more tool calls
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:

View File

@ -1,5 +1,5 @@
import enum
from typing import Any, Optional
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity):
class AgentStrategyParameter(PluginParameter):
class AgentStrategyParameterType(enum.StrEnum):
class AgentStrategyParameterType(StrEnum):
"""
Keep all the types from PluginParameterType
"""
STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value
SELECT = CommonParameterType.SELECT.value
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
ANY = CommonParameterType.ANY.value
STRING = CommonParameterType.STRING
NUMBER = CommonParameterType.NUMBER
BOOLEAN = CommonParameterType.BOOLEAN
SELECT = CommonParameterType.SELECT
SECRET_INPUT = CommonParameterType.SECRET_INPUT
FILE = CommonParameterType.FILE
FILES = CommonParameterType.FILES
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
@ -53,7 +53,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
help: I18nObject | None = None
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
@ -61,7 +61,7 @@ class AgentStrategyParameter(PluginParameter):
class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
plugin_id: str | None = Field(None, description="The id of the plugin")
class AgentStrategyIdentity(ToolIdentity):
@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity):
pass
class AgentFeature(enum.StrEnum):
class AgentFeature(StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""
@ -84,9 +84,9 @@ class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
meta_version: Optional[str] = None
output_schema: dict | None = None
features: list[AgentFeature] | None = None
meta_version: str | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
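
Dropping .value on the right-hand sides works because StrEnum members are themselves str instances, so they are valid enum values as-is. A simplified stand-in (the real CommonParameterType has more members):

    from enum import StrEnum

    class CommonParameterType(StrEnum):
        STRING = "string"

    class AgentStrategyParameterType(StrEnum):
        # a StrEnum member is a str, so no .value is needed here
        STRING = CommonParameterType.STRING

    assert AgentStrategyParameterType.STRING == "string"
    assert isinstance(CommonParameterType.STRING, str)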

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional
from typing import Any
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
@ -16,10 +16,10 @@ class BaseAgentStrategy(ABC):
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@ -37,9 +37,9 @@ class BaseAgentStrategy(ABC):
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Any, Optional
from typing import Any
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
@ -38,10 +38,10 @@ class PluginAgentStrategy(BaseAgentStrategy):
self,
params: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@ -21,7 +19,7 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
@ -38,7 +36,14 @@ class SensitiveWordAvoidanceConfigManager:
if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
if not isinstance(typ, str):
raise ValueError("sensitive_word_avoidance.type must be a string")
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
if sensitive_word_avoidance_config is None:
sensitive_word_avoidance_config = {}
if not isinstance(sensitive_word_avoidance_config, dict):
raise ValueError("sensitive_word_avoidance.config must be a dict")
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
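
The hardened path above fails fast on a non-string type and treats a missing "config" key as an empty dict instead of raising KeyError. The guard logic in isolation, on an illustrative payload:

    config = {"sensitive_word_avoidance": {"enabled": True, "type": "keywords"}}

    typ = config["sensitive_word_avoidance"]["type"]
    if not isinstance(typ, str):
        raise ValueError("sensitive_word_avoidance.type must be a string")

    swa_config = config["sensitive_word_avoidance"].get("config")  # absent key -> None
    if swa_config is None:
        swa_config = {}
    if not isinstance(swa_config, dict):
        raise ValueError("sensitive_word_avoidance.config must be a dict")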

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[AgentEntity]:
def convert(cls, config: dict) -> AgentEntity | None:
"""
Convert model config to agent entity

View File

@ -1,5 +1,4 @@
import uuid
from typing import Optional
from core.app.app_config.entities import (
DatasetEntity,
@ -14,7 +13,7 @@ from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[DatasetEntity]:
def convert(cls, config: dict) -> DatasetEntity | None:
"""
Convert model config to dataset entity
@ -158,7 +157,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
"""
Extract dataset config for legacy compatibility

View File

@ -4,8 +4,8 @@ from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from models.provider_ids import ModelProviderID
class ModelConfigManager:
@ -105,7 +105,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:
def validate_model_completion_params(cls, cp: dict):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
text = message.get("text")
if not isinstance(text, str):
raise ValueError("message text must be a string")
role = message.get("role")
if not isinstance(role, str):
raise ValueError("message role must be a string")
chat_prompt_messages.append(
AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
@ -66,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@ -86,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
@ -122,7 +126,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
def validate_post_prompt_and_set_defaults(cls, config: dict):
"""
Validate post_prompt and set defaults for prompt feature

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Literal, Optional
from enum import StrEnum, auto
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel):
provider: str
model: str
mode: Optional[str] = None
mode: str | None = None
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)
@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
role_prefix: RolePrefixEntity | None = None
class PromptTemplateEntity(BaseModel):
@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""
class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()
@classmethod
def value_of(cls, value: str):
@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel):
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
simple_prompt_template: str | None = None
advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
@ -112,11 +112,11 @@ class VariableEntity(BaseModel):
type: VariableEntityType
required: bool = False
hide: bool = False
max_length: Optional[int] = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: str | None = None
placeholder: str | None = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
@ -173,7 +183,7 @@ class ModelConfig(BaseModel):
class Condition(BaseModel):
"""
Conditon detail
Condition detail
"""
name: str
@ -186,8 +196,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
@ -195,14 +205,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()
@classmethod
def value_of(cls, value: str):
@ -217,18 +227,18 @@ class DatasetRetrieveConfigEntity(BaseModel):
return mode
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
query_variable: str | None = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
class DatasetEntity(BaseModel):
@ -255,8 +265,8 @@ class TextToSpeechEntity(BaseModel):
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
voice: str | None = None
language: str | None = None
class TracingConfigEntity(BaseModel):
@ -269,15 +279,15 @@ class TracingConfigEntity(BaseModel):
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None
file_upload: FileUploadConfig | None = None
opening_statement: str | None = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
text_to_speech: TextToSpeechEntity | None = None
trace_config: TracingConfigEntity | None = None
class AppConfig(BaseModel):
@ -288,17 +298,17 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""
ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
@ -313,7 +323,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_dict: dict
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: Optional[DatasetEntity] = None
dataset: DatasetEntity | None = None
external_data_variables: list[ExternalDataVariableEntity] = []
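
PromptType, RetrieveStrategy, and EasyUIBasedAppModelConfigFrom keep their serialized values after the StrEnum/auto() switch, because auto() on a StrEnum yields the lower-cased member name. A quick check:

    from enum import StrEnum, auto

    class PromptType(StrEnum):
        SIMPLE = auto()      # value becomes "simple"
        ADVANCED = auto()    # value becomes "advanced"

    assert PromptType.SIMPLE == "simple"
    assert PromptType("advanced") is PromptType.ADVANCED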

View File

@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
except ValidationError:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@ -1,4 +1,6 @@
from core.app.app_config.entities import VariableEntity
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow
@ -20,3 +22,48 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""
Convert a workflow's rag pipeline variables to variable entities
:param workflow: workflow instance
:param start_node_id: id of the datasource (start) node
"""
variables = []
# get rag pipeline variables (the second-step inputs)
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
# get datasource node data
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if datasource_node_data:
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for _, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split(".")[-1]
variables_map.pop(last_part, None)
if value.get("value") and isinstance(value.get("value"), list):
last_part = value.get("value")[-1]
variables_map.pop(last_part, None)
all_second_step_variables = list(variables_map.values())
for item in all_second_step_variables:
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(item))
return variables
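
The pattern in convert_rag_pipeline_variable matches Dify-style variable references of the form {{#node.path#}}; group 1 captures the dotted path, and its last segment names the variable to drop from the map. A small demonstration (the reference string is illustrative):

    import re

    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
    match = re.match(pattern, "{{#start_node.user_query#}}")
    if match:
        full_path = match.group(1)              # "start_node.user_query"
        last_part = full_path.split(".")[-1]    # "user_query"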

View File

@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for advanced chat app model config

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@ -420,7 +420,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
db.session.refresh(conversation)
# get conversation dialogue count
self._dialogue_count = get_thread_messages_length(conversation.id)
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -450,6 +452,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
@ -461,7 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -475,7 +483,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
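
The refresh() calls before db.session.close() above eagerly load the ORM instances while the session is still open, so the detached objects stay readable from the streaming thread without lazy-load errors. A hypothetical helper capturing the idiom (not part of this commit):

    from sqlalchemy.orm import Session

    def detach_for_worker(session: Session, *instances) -> None:
        # load all attributes now; after close() the instances are detached
        # and any unloaded attribute would raise on access
        for instance in instances:
            session.refresh(instance)
        session.close()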

View File

@ -1,11 +1,11 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType
from models.workflow import ConversationVariable
logger = logging.getLogger(__name__)
@ -54,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
@ -68,31 +69,22 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.system_user_id = system_user_id
self._app = app
def run(self) -> None:
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
@ -144,16 +136,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
db.session.close()
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -165,11 +168,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)
@ -219,7 +222,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
"""
Direct output
"""
@ -229,7 +232,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
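
model_dump(mode="json") is Pydantic v2's JSON-compatible serialization (datetimes to ISO strings, enums to their values), which is what the stream layer needs and what the hand-written to_dict() used to provide. An illustrative model:

    from datetime import datetime
    from pydantic import BaseModel

    class StreamPayload(BaseModel):
        created_at: datetime
        answer: str

    payload = StreamPayload(created_at=datetime(2025, 9, 25, 17, 14), answer="hi")
    assert payload.model_dump(mode="json") == {
        "created_at": "2025-09-25T17:14:00",
        "answer": "hi",
    }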

View File

@ -4,7 +4,7 @@ import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from threading import Thread
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -65,15 +60,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -102,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -175,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -234,7 +228,7 @@ class AdvancedChatAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -290,12 +284,12 @@ class AdvancedChatAppGenerateTaskPipeline:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -303,21 +297,16 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
with self._database_session() as session:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@ -338,15 +327,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
yield node_retry_resp
@ -380,13 +368,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
@ -395,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -419,8 +404,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -442,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -546,8 +505,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -577,8 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -609,8 +568,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
@ -635,17 +594,17 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_stop_event(
self,
event: QueueStopEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
@ -685,13 +644,13 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
graph_runtime_state: GraphRuntimeState | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
@ -759,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -783,10 +740,10 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -808,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -822,31 +777,20 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
@ -856,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break
@ -896,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -907,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
@ -939,10 +878,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
@ -967,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline:
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
@ -979,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return True
else:
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
return False
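
_process_stream_response dispatches queue events with structural pattern matching: a class pattern selects on the event's type and can bind attributes in one step. The idiom reduced to stand-in types:

    class QueuePingEvent: ...

    class QueueTextChunkEvent:
        def __init__(self, text: str):
            self.text = text

    def dispatch(event):
        match event:
            case QueueTextChunkEvent(text=text):   # isinstance check + attribute capture
                return f"chunk: {text}"
            case QueuePingEvent():
                return "ping"
            case _:                                # unhandled events fall through
                return None

    assert dispatch(QueueTextChunkEvent("hello")) == "chunk: hello"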

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, cast
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
agent: AgentEntity | None = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
"""
Validate for agent chat app model config
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return filtered_config
@classmethod
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate agent_mode and set defaults for agent feature
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
agent_mode = cast(dict[str, Any], agent_mode)
if not isinstance(config["agent_mode"]["enabled"], bool):
if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False
if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not agent_mode.get("tools"):
agent_mode["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
for tool in config["agent_mode"]["tools"]:
for tool in agent_mode["tools"]:
key = list(tool.keys())[0]
if key in OLD_TOOLS:
# old style, use tool name as key

View File

@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -33,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run assistant application
:param application_generate_entity: application generate entity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
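
These converters switch from to_dict() to Pydantic v2's model_dump(mode="json"), which coerces nested enums, datetimes, and similar types into JSON-safe primitives. A small illustration of the difference (model and enum names are invented for the example):

from enum import Enum

from pydantic import BaseModel


class StreamEvent(str, Enum):
    MESSAGE = "message"


class Chunk(BaseModel):
    event: StreamEvent


chunk = Chunk(event=StreamEvent.MESSAGE)
print(chunk.model_dump())             # {'event': <StreamEvent.MESSAGE: 'message'>}
print(chunk.model_dump(mode="json"))  # {'event': 'message'}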

View File

@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict:
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception

View File

@ -1,12 +1,12 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,
@ -14,6 +14,7 @@ from core.workflow.repositories.draft_variable_repository import (
)
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
@ -24,7 +25,7 @@ class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
user_inputs: Mapping[str, Any] | None,
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,
@ -44,9 +45,9 @@ class BaseAppGenerator:
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
)
@ -59,9 +60,9 @@ class BaseAppGenerator:
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
)
for k, v in user_inputs.items()
@ -157,7 +158,7 @@ class BaseAppGenerator:
return value
def _sanitize_value(self, value: Any) -> Any:
def _sanitize_value(self, value: Any):
if isinstance(value, str):
return value.replace("\x00", "")
return value
@ -182,8 +183,9 @@ class BaseAppGenerator:
@final
@staticmethod
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
def draft_var_saver_factory(
session: Session,
@ -200,6 +202,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
)
else:

View File

@ -2,7 +2,7 @@ import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any, Optional
from typing import Any
from sqlalchemy.orm import DeclarativeMeta
@ -25,13 +25,14 @@ class PublishFrom(IntEnum):
class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
@ -73,14 +74,14 @@ class AppQueueManager:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
def stop_listen(self):
"""
Stop listening to the queue
:return:
"""
self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom) -> None:
def publish_error(self, e, pub_from: PublishFrom):
"""
Publish error
:param e: error
@ -89,7 +90,7 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@ -100,7 +101,7 @@ class AppQueueManager:
self._publish(event, pub_from)
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@ -110,12 +111,12 @@ class AppQueueManager:
raise NotImplementedError
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
"""
Set task stop flag
:return:
"""
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
@ -126,6 +127,21 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@classmethod
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
"""
Set task stop flag without user permission check.
This method allows stopping workflows without user context.
:param task_id: The task ID to stop
:return:
"""
if not task_id:
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
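
A usage sketch for the two stop paths, assuming a task id recovered from a stored run record (the identifiers are placeholders):

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom

# Owner-checked stop: only proceeds when the cached task owner matches user_id.
AppQueueManager.set_stop_flag(task_id="task-id", invoke_from=InvokeFrom.DEBUGGER, user_id="account-id")

# Unchecked stop, e.g. from a maintenance job that has no user context.
AppQueueManager.set_stop_flag_no_user_check(task_id="task-id")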
def _is_stopped(self) -> bool:
"""
Check if task is stopped
@ -159,7 +175,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -82,11 +82,11 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
query: str | None = None,
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
:param context:
@ -161,8 +161,8 @@ class AppRunner:
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
usage: LLMUsage | None = None,
):
"""
Direct output
:param queue_manager: application queue manager
@ -204,7 +204,7 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result
@ -220,9 +220,7 @@ class AppRunner:
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
"""
Handle invoke result direct
:param invoke_result: invoke result
@ -239,7 +237,7 @@ class AppRunner:
def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result
@ -377,7 +375,7 @@ class AppRunner:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
@ -81,7 +79,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for chat app model config

View File

@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -31,7 +33,7 @@ class ChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
@ -162,7 +164,9 @@ class ChatAppRunner(AppRunner):
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
hit_callback=hit_callback,
memory=memory,
message_id=message.id,

View File

@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from typing import Any, Union
from sqlalchemy.orm import Session
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
@ -36,24 +31,23 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
EndUser,
)
from services.variable_truncator import VariableTruncator
class WorkflowResponseConverter:
@ -62,9 +56,10 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._user = user
self._truncator = VariableTruncator.default()
def workflow_start_to_stream_response(
self,
@ -140,7 +135,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -156,7 +151,8 @@ class WorkflowResponseConverter:
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
@ -171,11 +167,19 @@ class WorkflowResponseConverter:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
return response
@ -183,14 +187,10 @@ class WorkflowResponseConverter:
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -210,9 +210,12 @@ class WorkflowResponseConverter:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -221,9 +224,6 @@ class WorkflowResponseConverter:
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -235,7 +235,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -255,9 +255,12 @@ class WorkflowResponseConverter:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -275,50 +278,6 @@ class WorkflowResponseConverter:
),
)
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
@ -326,6 +285,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationStartEvent,
) -> IterationNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -333,13 +293,12 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
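
Each start/completed response here funnels its variable mapping through the truncator and carries an *_truncated flag. A minimal sketch of the contract implied by these call sites (VariableTruncator.truncate_variable_mapping returning the trimmed mapping plus a flag); the real truncator's limits and rules may differ:

def truncate_variable_mapping(mapping: dict[str, object], max_chars: int = 1024) -> tuple[dict[str, object], bool]:
    # Sketch only: trim oversized string renderings and report whether anything was cut.
    result: dict[str, object] = {}
    truncated = False
    for key, value in mapping.items():
        rendered = value if isinstance(value, str) else repr(value)
        if len(rendered) > max_chars:
            result[key] = rendered[:max_chars] + "..."
            truncated = True
        else:
            result[key] = value
    return result, truncated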
@ -357,15 +316,10 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -377,6 +331,11 @@ class WorkflowResponseConverter:
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -384,28 +343,29 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=json_converter.to_json_encodable(event.outputs),
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def workflow_loop_start_to_stream_response(
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -413,10 +373,11 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
@ -437,15 +398,16 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_loop_output=event.output,
# The `pre_loop_output` field is not utilized by the frontend.
# Previously, it was assigned the value of `event.output`.
pre_loop_output={},
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -456,6 +418,11 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -463,17 +430,19 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
@ -66,7 +64,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for completion app model config

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -191,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -248,28 +249,30 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -25,7 +27,7 @@ class CompletionAppRunner(AppRunner):
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
@ -122,7 +124,9 @@ class CompletionAppRunner(AppRunner):
config=dataset_config,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=app_config.additional_features.show_retrieve_source
if app_config.additional_features
else False,
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,

View File

@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,10 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
@ -81,13 +84,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -110,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")

View File

@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
):
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
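
A hypothetical consumer of the converter above, adapting the chunks into a Server-Sent Events body ("to_sse" and the framing are illustrative, not part of the codebase):

import json
from collections.abc import Generator


def to_sse(stream) -> Generator[str, None, None]:
    # `stream` is assumed to be the Generator[AppStreamResponse, None, None]
    # produced by the workflow app generator.
    for chunk in WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream):
        if chunk == "ping":
            yield "event: ping\n\n"  # keep-alive frame
        else:
            yield f"data: {json.dumps(chunk)}\n\n"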

View File

@ -0,0 +1,66 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
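
Each manager invoked above follows the same validate-and-set-defaults contract: normalize its slice of the config and report the top-level keys it owns, so the final dict can be filtered down to known keys. A schematic example (the manager name and config key are invented):

class ExampleFeatureConfigManager:
    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
        # Normalize this feature's slice of the config in place.
        feature = config.setdefault("example_feature", {})
        if not isinstance(feature, dict):
            raise ValueError("example_feature must be of object type")
        feature.setdefault("enabled", False)
        # Report which top-level keys this manager is responsible for.
        return config, ["example_feature"]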

View File

@ -0,0 +1,856 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# The pipeline must have an associated dataset
with Session(db.engine, expire_on_commit=False) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
original_document_id=args.get("original_document_id"),
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=invoke_from == InvokeFrom.SERVICE_API,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow execution and workflow node execution repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities in object storage
entity_dicts = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert the list to a proper JSON string
json_text = json.dumps(entity_dicts)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
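# A worker-side sketch of the sandbox throttle above, which pairs a Redis flag
# key (`tenant_pipeline_task:<tenant_id>`) with a waiting list
# (`tenant_self_pipeline_task_queue:<tenant_id>`). The drain step below is an
# assumption about what rag_pipeline_run_task does when it finishes, which this
# diff does not show; helper name and signature are hypothetical.
def drain_next_pipeline_task(redis_client, tenant_id: str) -> str | None:
    queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"
    flag_key = f"tenant_pipeline_task:{tenant_id}"
    next_file_id = redis_client.rpop(queue_key)  # lpush + rpop => FIFO order
    if next_file_id is None:
        redis_client.delete(flag_key)  # queue empty: clear the in-flight flag
        return None
    redis_client.set(flag_key, 1, ex=60 * 60)  # re-arm the flag for the next batch
    return next_file_id.decode() if isinstance(next_file_id, bytes) else next_file_id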
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow_id: workflow ID
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
user,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
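# `_generate` returns either a Mapping (blocking mode) or a Generator of
# stream chunks (streaming mode), depending on `streaming`. A minimal
# caller-side sketch; `emit` is a hypothetical sink for chunks.
from collections.abc import Generator

def consume_pipeline_response(response, emit) -> None:
    if isinstance(response, Generator):
        for chunk in response:  # each chunk is a str or a Mapping event
            emit(chunk)
    else:
        emit(response)  # single blocking Mapping result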
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow execution and node execution repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow execution and node execution repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param context: context variables to carry into the new thread
:param variable_loader: variable loader
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
)
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
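# Note the error contract above: exceptions in the worker thread are not
# raised to the caller but published to the queue, where the task pipeline
# converts them into error stream responses. A sketch of that contract,
# assuming a hypothetical iterable of consumed queue events:
def first_error(events):
    """Hypothetical helper: return the first published error event, if any."""
    return next((e for e in events if isinstance(e, QueueErrorEvent)), None)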
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param draft_var_saver_factory: factory for draft variable savers
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %r",
application_generate_entity.task_id,
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
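# `_build_document` derives the document name from a different field per
# datasource type. The same mapping written as data, for reference; field
# names mirror the branches above.
_NAME_RESOLVERS = {
    "local_file": lambda info: info.get("name", "untitled"),
    "online_document": lambda info: info.get("page", {}).get("page_name", "untitled"),
    "website_crawl": lambda info: info.get("title", "untitled"),
    "online_drive": lambda info: info.get("name", "untitled"),
}
# e.g. _NAME_RESOLVERS["website_crawl"]({"title": "Docs home"}) -> "Docs home"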
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=datasource_node_data.get("credential_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("id") and datasource_info.get("type") == "folder":
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("id", ""),
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": datasource_info.get("id", ""),
"name": datasource_info.get("name", "untitled"),
"bucket": datasource_info.get("bucket", None),
}
)
return all_files
else:
return datasource_info_list
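# For "online_drive", the formatter expands folder entries into the files they
# contain (recursively, via _get_files_in_folder); every other datasource type
# passes through unchanged. Illustrative shapes only, with hypothetical IDs:
#
#   input:  [{"id": "folder-1", "type": "folder", "bucket": "b"}]
#   output: [{"id": "file-a", "name": "a.pdf", "bucket": "b"},
#            {"id": "file-b", "name": "b.pdf", "bucket": "b"}]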
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
next_page_parameters=next_page_parameters,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
for result in result_generator:
for files in result.result:
for file in files.files:
if file.type == "folder":
self._get_files_in_folder(
datasource_runtime,
file.id,
bucket,
user_id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": file.id,
"name": file.name,
"bucket": bucket,
}
)
is_truncated = files.is_truncated
next_page_parameters = files.next_page_parameters
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
)
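# `_get_files_in_folder` interleaves two traversals: recursion into sub-folders
# and cursor pagination driven by `is_truncated` / `next_page_parameters`. The
# same shape reduced to a generic sketch; `list_page` is a hypothetical
# callable standing in for online_drive_browse_files.
from collections.abc import Callable, Mapping
from typing import Any

def walk_drive(list_page: Callable[[str, dict | None], Mapping[str, Any]], prefix: str, out: list[dict]) -> None:
    page_params: dict | None = None
    while True:
        page = list_page(prefix, page_params)  # one page of entries under prefix
        for entry in page["files"]:
            if entry["type"] == "folder":
                walk_drive(list_page, entry["id"], out)  # depth-first into folders
            else:
                out.append({"id": entry["id"], "name": entry["name"]})
        if not page["is_truncated"]:
            break
        page_params = page["next_page_parameters"]  # cursor for the next page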

View File

@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()
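# The manager stops listening as soon as any terminal event is published:
# stop, error, message end, or workflow succeeded / failed / partial success.
# A minimal consumer sketch, assuming `listen()` inherited from
# AppQueueManager and a hypothetical `handle` callback; the app_mode string is
# an assumption standing in for AppMode.RAG_PIPELINE.
def consume(task_id: str, user_id: str, handle) -> None:
    queue_manager = PipelineQueueManager(
        task_id=task_id,
        user_id=user_id,
        invoke_from=InvokeFrom.DEBUGGER,
        app_mode="rag-pipeline",
    )
    for message in queue_manager.listen():  # yields WorkflowQueueMessage items
        handle(message.event)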

View File

@ -0,0 +1,263 @@
import logging
import time
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param variable_loader: variable loader
:param workflow: workflow to run
:param system_user_id: system-level user ID used for workflow execution
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
)
generator = workflow_entry.run()
for event in generator:
self._update_document_status(
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# nodes = graph_config.get("nodes", [])
# edges = graph_config.get("edges", [])
# real_run_nodes = []
# real_edges = []
# exclude_node_ids = []
# for node in nodes:
# node_id = node.get("id")
# node_type = node.get("data", {}).get("type", "")
# if node_type == "datasource":
# if start_node_id != node_id:
# exclude_node_ids.append(node_id)
# continue
# real_run_nodes.append(node)
# for edge in edges:
# if edge.get("source") in exclude_node_ids:
# continue
# real_edges.append(edge)
# graph_config = dict(graph_config)
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
"""
Update document status
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()
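# `run()` only binds rag-pipeline variables that (a) belong to the datasource
# start node or the "shared" scope and (b) actually appear in the user inputs.
# The same filter isolated for reference; `declared` mirrors the
# RAGPipelineVariable objects built from workflow.rag_pipeline_variables.
def select_pipeline_inputs(declared, inputs: dict, start_node_id: str | None):
    return [
        RAGPipelineVariableInput(variable=var, value=inputs[var.variable])
        for var in declared
        if var.belong_to_node_id in (start_node_id, "shared") and var.variable in inputs
    ]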

View File

@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for workflow app model config

View File

@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Generator[Mapping | str, None, None]: ...
@ -69,7 +68,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@ -85,7 +83,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
@ -100,7 +97,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
@ -200,7 +196,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
root_node_id=root_node_id,
)
@ -215,7 +210,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
@ -230,7 +224,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -253,7 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
},
@ -261,9 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
# return response or stream generator
response = self._handle_response(
@ -451,7 +441,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
root_node_id: Optional[str] = None,
) -> None:
"""
@ -462,7 +451,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
@ -492,7 +480,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,

View File

@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

@ -1,7 +1,7 @@
import logging
from typing import Optional, cast
import time
from typing import cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -31,47 +32,33 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
root_node_id: Optional[str] = None,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
def run(self) -> None:
def run(self):
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
@ -94,15 +81,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
root_node_id=self._root_node_id,
)
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -114,10 +114,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
return blocking_response.model_dump()
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any, Optional, Union
from typing import Union
from sqlalchemy.orm import Session
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -92,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -137,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -146,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline:
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -206,7 +202,7 @@ class WorkflowAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -263,12 +259,12 @@ class WorkflowAppGenerateTaskPipeline:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -276,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
@ -300,16 +296,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
@ -350,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -371,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -475,8 +442,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -509,8 +476,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -544,8 +511,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
@ -582,8 +549,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -618,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -634,12 +599,12 @@ class WorkflowAppGenerateTaskPipeline:
def _dispatch_event(
self,
event: Any,
event: AppQueueEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -661,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -675,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
@ -702,8 +654,8 @@ class WorkflowAppGenerateTaskPipeline:
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
@ -745,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@ -770,7 +722,7 @@ class WorkflowAppGenerateTaskPipeline:
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.

View File

@ -1,7 +1,9 @@
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -13,14 +15,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -28,42 +25,39 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeInIterationFailedEvent,
NodeInLoopFailedEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
@ -74,12 +68,20 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: Optional[str] = None,
) -> Graph:
"""
Init graph
"""
@ -91,22 +93,109 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
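# A minimal sketch (editorial, not part of the commit) of the intended call
# pattern for single-node debugging, assuming the generate entity carries one
# of the two run entities; `entity` is a placeholder name:
#
#     graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
#         workflow=workflow,
#         single_iteration_run=entity.single_iteration_run,  # or single_loop_run=...
#     )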
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@ -124,18 +213,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
@ -145,37 +238,50 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break
if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
@ -196,103 +302,45 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
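# Editorial sketch (not part of the commit): the node filter above keeps
# (a) the target container node, (b) nodes tagged with the matching
# iteration_id/loop_id, and (c) the container's declared start node.
# Standalone equivalent with invented data:
#
#     nodes = [
#         {"id": "iter_1", "data": {"start_node_id": "start_1"}},
#         {"id": "start_1", "data": {}},
#         {"id": "llm_1", "data": {"iteration_id": "iter_1"}},
#         {"id": "unrelated", "data": {}},
#     ]
#     kept = [
#         n for n in nodes
#         if n["id"] == "iter_1"
#         or n["data"].get("iteration_id", "") == "iter_1"
#         or n["id"] == "start_1"
#     ]
#     assert [n["id"] for n in kept] == ["iter_1", "start_1", "llm_1"]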
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool of single loop
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
:param workflow_entry: workflow entry
@ -310,39 +358,32 @@ class WorkflowBasedAppRunner:
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
@ -350,44 +391,29 @@ class WorkflowBasedAppRunner:
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
start_at=event.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -396,34 +422,18 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -434,93 +444,21 @@ class WorkflowBasedAppRunner:
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeInLoopFailedEvent):
self._publish_event(
QueueNodeInLoopFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -533,10 +471,10 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, AgentLogEvent):
elif isinstance(event, NodeRunAgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
@ -547,51 +485,13 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
elif isinstance(event, NodeRunIterationStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -599,55 +499,41 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, IterationRunNextEvent):
elif isinstance(event, NodeRunIterationNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
)
)
elif isinstance(event, LoopRunStartedEvent):
elif isinstance(event, NodeRunLoopStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -655,44 +541,34 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, LoopRunNextEvent):
elif isinstance(event, NodeRunLoopNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
)
)
def _publish_event(self, event: AppQueueEvent) -> None:
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
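# Editorial note: _handle_event is a flat isinstance dispatch translating each
# GraphEngineEvent into its queue counterpart. Trimmed sketch of the pattern
# (the two branches shown are taken verbatim from above; the rest are elided):
#
#     def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
#         if isinstance(event, GraphRunFailedEvent):
#             self._publish_event(
#                 QueueWorkflowFailedEvent(error=event.error,
#                                          exceptions_count=event.exceptions_count))
#         elif isinstance(event, NodeRunStreamChunkEvent):
#             self._publish_event(
#                 QueueTextChunkEvent(text=event.chunk,
#                                     from_variable_selector=list(event.selector),
#                                     in_iteration_id=event.in_iteration_id,
#                                     in_loop_id=event.in_loop_id))
#         # ...one branch per remaining event type...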

View File

@ -1,9 +1,12 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
@ -11,7 +14,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(Enum):
class InvokeFrom(StrEnum):
"""
Invoke From.
"""
@ -35,6 +38,7 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
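# Editorial note (stdlib fact, not part of the diff): StrEnum members compare
# and serialize as their string values, which plain Enum members do not:
#
#     assert InvokeFrom.PUBLISHED == "published"
#     assert str(InvokeFrom.PUBLISHED) == "published"  # plain Enum would yield
#                                                      # "InvokeFrom.PUBLISHED"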
@classmethod
def value_of(cls, value: str):
@ -95,8 +99,8 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
app_config: Any = None
file_upload_config: FileUploadConfig | None = None
inputs: Mapping[str, Any]
files: Sequence[File]
@ -113,8 +117,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
# Using Any to avoid circular import with TraceQueueManager
trace_manager: Optional[Any] = None
trace_manager: Optional["TraceQueueManager"] = None
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@ -123,10 +126,10 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: EasyUIBasedAppConfig
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None
query: str | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -137,8 +140,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
Base entity for conversation-based app generation.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = Field(
conversation_id: str | None = None
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
@ -186,9 +189,9 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
query: str
class SingleIterationRunEntity(BaseModel):
@ -199,7 +202,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""
@ -209,7 +212,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
@ -218,7 +221,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str
class SingleIterationRunEntity(BaseModel):
@ -229,7 +232,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""
@ -239,4 +242,35 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: str | None = None
original_document_id: str | None = None
start_node_id: str | None = None
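# Editorial sketch (not part of the commit): constructing the new pipeline
# entity. All values are invented placeholders, and inherited required fields
# are elided:
#
#     entity = RagPipelineGenerateEntity(
#         pipeline_config=pipeline_config,        # a WorkflowUIBasedAppConfig
#         datasource_type="upload_file",          # hypothetical value
#         datasource_info={"file_id": "..."},
#         dataset_id=dataset.id,
#         batch=batch_key,                        # hypothetical batch identifier
#         workflow_execution_id=str(uuid.uuid4()),
#         # ...plus the inherited AppGenerateEntity fields (task_id, inputs, files, ...)
#     )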
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()
ChatAppGenerateEntity.model_rebuild()
CompletionAppGenerateEntity.model_rebuild()
AgentChatAppGenerateEntity.model_rebuild()
AdvancedChatAppGenerateEntity.model_rebuild()
WorkflowAppGenerateEntity.model_rebuild()
RagPipelineGenerateEntity.model_rebuild()
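# Editorial note: the rebuild calls resolve the string annotation
# "TraceQueueManager" once the real class is importable. Minimal standalone
# analogue (module and class names hypothetical):
#
#     from typing import TYPE_CHECKING
#     from pydantic import BaseModel
#
#     if TYPE_CHECKING:
#         from helpers import Helper  # hypothetical module
#
#     class M(BaseModel):
#         helper: "Helper | None" = None
#
#     from helpers import Helper  # real import at runtime
#     M.model_rebuild()           # forward reference now resolves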

View File

@ -1,17 +1,15 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(StrEnum):
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueIterationNextEvent(AppQueueEvent):
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_title: str
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
output: Any = None # output for the current iteration
class QueueIterationCompletedEvent(AppQueueEvent):
@ -134,24 +110,16 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueLoopStartEvent(AppQueueEvent):
@ -163,21 +131,21 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueLoopNextEvent(AppQueueEvent):
@ -191,20 +159,19 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
output: Any = None # output for the current loop
class QueueLoopCompletedEvent(AppQueueEvent):
@ -217,24 +184,24 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueTextChunkEvent(AppQueueEvent):
@ -244,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -285,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -306,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
llm_result: LLMResult | None = None
class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueNodeStartedEvent(AppQueueEvent):
@ -364,26 +331,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_title: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
node_run_index: int = 1  # FIXME(-LAN-): may not be used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, needs refactoring
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str
class QueueNodeSucceededEvent(AppQueueEvent):
@ -396,31 +360,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: Optional[str] = None
"""single iteration duration map"""
iteration_duration_map: Optional[dict[str, float]] = None
"""single loop duration map"""
loop_duration_map: Optional[dict[str, float]] = None
error: str | None = None
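# Editorial note: fields that were Optional[...] = None above are now
# Mapping[str, object] = Field(default_factory=dict), so consumers receive an
# empty mapping instead of None, and each instance gets a fresh dict:
#
#     e1 = QueueNodeSucceededEvent(node_execution_id="1", node_id="n",
#                                  node_type=node_type, start_at=started)  # placeholder args
#     e2 = QueueNodeSucceededEvent(node_execution_id="2", node_id="n",
#                                  node_type=node_type, start_at=started)
#     assert e1.inputs == {} and e1.inputs is not e2.inputs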
class QueueAgentLogEvent(AppQueueEvent):
@ -432,11 +391,11 @@ class QueueAgentLogEvent(AppQueueEvent):
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
event: QueueEvent = QueueEvent.RETRY
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
retry_index: int # retry index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeInLoopFailedEvent(AppQueueEvent):
"""
QueueNodeInLoopFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
@ -530,25 +423,24 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -563,25 +455,17 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.ERROR
error: Optional[Any] = None
error: Any = None
class QueuePingEvent(AppQueueEvent):
@ -626,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""
class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
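# Editorial note (stdlib fact): with enum.StrEnum, auto() produces the
# lowercased member name, so these values use underscores where the removed
# literals used hyphens:
#
#     from enum import StrEnum, auto
#     class Demo(StrEnum):
#         USER_MANUAL = auto()
#     assert Demo.USER_MANUAL == "user_manual"   # previous literal: "user-manual"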
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
error: str

View File

@ -0,0 +1,14 @@
from typing import Any
from pydantic import BaseModel
class RagPipelineInvokeEntity(BaseModel):
pipeline_id: str
application_generate_entity: dict[str, Any]
user_id: str
tenant_id: str
workflow_id: str
streaming: bool
workflow_execution_id: str | None = None
workflow_thread_pool_id: str | None = None
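# Editorial sketch: the plain-dict application_generate_entity suggests this
# model is serialized across a queue boundary. Hypothetical round-trip using
# standard Pydantic v2 APIs (all values are placeholders):
#
#     payload = RagPipelineInvokeEntity(
#         pipeline_id=pipeline.id,
#         application_generate_entity=entity.model_dump(mode="json"),
#         user_id=user.id,
#         tenant_id=tenant.id,
#         workflow_id=workflow.id,
#         streaming=True,
#     )
#     raw = payload.model_dump_json()                              # enqueue
#     restored = RagPipelineInvokeEntity.model_validate_json(raw)  # dequeue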

View File

@ -1,14 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@ -51,7 +50,7 @@ class WorkflowTaskState(TaskState):
answer: str = ""
class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
@ -71,8 +70,6 @@ class StreamEvent(Enum):
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -92,9 +89,6 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ErrorStreamResponse(StreamResponse):
"""
@ -114,7 +108,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
class MessageAudioStreamResponse(StreamResponse):
@ -142,8 +136,8 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = Field(default_factory=dict)
files: Optional[Sequence[Mapping[str, Any]]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
files: Sequence[Mapping[str, Any]] | None = None
class MessageFileStreamResponse(StreamResponse):
@ -176,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
thought: Optional[str] = None
observation: Optional[str] = None
tool: Optional[str] = None
tool_labels: Optional[dict] = None
tool_input: Optional[str] = None
message_files: Optional[list[str]] = None
thought: str | None = None
observation: str | None = None
tool: str | None = None
tool_labels: Mapping[str, object] = Field(default_factory=dict)
tool_input: str | None = None
message_files: list[str] | None = None
class AgentMessageStreamResponse(StreamResponse):
@ -227,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_by: Optional[dict] = None
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -258,18 +252,19 @@ class NodeStartStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
created_at: int
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
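# Editorial note: the new *_truncated flags paired with inputs/process_data/
# outputs are inferred from the field names alone to mark payloads shortened
# before streaming; treat the flag, not the payload size, as authoritative.
# Hypothetical consumer check:
#
#     if data.inputs_truncated:
#         inputs = fetch_full_node_inputs(data.node_id)  # hypothetical helper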
@ -315,23 +310,26 @@ class NodeFinishStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -384,23 +382,26 @@ class NodeRetryStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
@ -440,54 +441,6 @@ class NodeRetryStreamResponse(StreamResponse):
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -506,8 +459,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
inputs_truncated: bool = False
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -530,12 +482,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -556,19 +503,19 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -593,8 +540,9 @@ class LoopNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@ -617,12 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_loop_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@ -643,19 +590,21 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str
@ -673,7 +622,7 @@ class TextChunkStreamResponse(StreamResponse):
"""
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -735,7 +684,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
class AppBlockingResponse(BaseModel):
@ -745,9 +694,6 @@ class AppBlockingResponse(BaseModel):
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
@ -764,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
conversation_id: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -784,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
mode: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -803,8 +749,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
@ -828,11 +774,11 @@ class AgentLogStreamResponse(StreamResponse):
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
event: StreamEvent = StreamEvent.AGENT_LOG

View File

@ -1,5 +1,6 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
@ -25,15 +26,17 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
return None
try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name
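
This hunk is one instance of a pattern repeated across the commit: legacy `db.session.query(...).first()` reads become a 2.0-style `select()` statement executed with `db.session.scalar()`. A minimal self-contained sketch of that pattern, using a hypothetical model in place of AppAnnotationSetting:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class AnnotationSettingDemo(Base):  # hypothetical stand-in for AppAnnotationSetting
    __tablename__ = "annotation_settings_demo"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(AnnotationSettingDemo(app_id="app-1"))
    session.commit()
    # 2.0 style: build a select() and let scalar() return the first row's
    # ORM object, or None when nothing matches
    stmt = select(AnnotationSettingDemo).where(AnnotationSettingDemo.app_id == "app-1")
    assert session.scalar(stmt) is not None
    assert session.scalar(select(AnnotationSettingDemo).where(AnnotationSettingDemo.app_id == "missing")) is None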

View File

@ -1 +1,3 @@
from .rate_limit import RateLimit
__all__ = ["RateLimit"]

View File

@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@ -63,7 +63,7 @@ class RateLimit:
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
def enter(self, request_id: str | None = None) -> str:
if self.disabled():
return RateLimit._UNLIMITED_REQUEST_ID
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator:
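
RateLimit.__new__ above keeps one instance per client_id in _instance_dict. A minimal sketch of that per-key singleton pattern, with hypothetical names:

class PerKeySingleton:
    _instances: dict[str, "PerKeySingleton"] = {}

    def __new__(cls, key: str):
        # return the cached instance for this key instead of a fresh object
        if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
        return cls._instances[key]

    def __init__(self, key: str):
        # note: __init__ still runs on every construction, so it must be idempotent
        self.key = key


assert PerKeySingleton("client-1") is PerKeySingleton("client-1")
assert PerKeySingleton("client-1") is not PerKeySingleton("client-2")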

View File

@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -35,14 +34,14 @@ class BasedGenerateTaskPipeline:
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@ -50,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
@ -86,7 +85,7 @@ class BasedGenerateTaskPipeline:
return message
def _error_to_stream_response(self, e: Exception):
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
@ -94,14 +93,14 @@ class BasedGenerateTaskPipeline:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
"""
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
def _init_output_moderation(self) -> Optional[OutputModeration]:
def _init_output_moderation(self) -> OutputModeration | None:
"""
Init output moderation.
:return:
@ -118,21 +117,21 @@ class BasedGenerateTaskPipeline:
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
"""
Handle output moderation when task finished.
:param completion: completion
:return:
"""
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
completion, flagged = self._output_moderation_handler.moderation_completion(
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None
self.output_moderation_handler = None
if flagged:
return completion

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
conversation: Conversation,
message: Message,
stream: bool,
) -> None:
):
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None
self._conversation_name_generate_thread: Thread | None = None
def process(
self,
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self.ping_stream_response()
else:
continue
if publisher:
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""
Save message.
:return:
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
application_generate_entity=self._application_generate_entity,
)
def _handle_stop(self, event: QueueStopEvent) -> None:
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
:return:
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
"""
Agent thought to stream response.
:param event: agent thought event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self.output_moderation_handler.append_new_token(text)
return False

View File

@ -1,8 +1,10 @@
import logging
from threading import Thread
from typing import Optional, Union
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -46,11 +48,11 @@ class MessageCycleManager:
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
Generate conversation name.
:param conversation_id: conversation id
@ -84,30 +86,32 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
db.session.merge(conversation)
db.session.commit()
db.session.close()
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
"""
Handle annotation reply.
:param event: event
@ -128,22 +132,25 @@ class MessageCycleManager:
return None
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
"""
Handle retriever resources.
:param event: event
:return:
"""
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""
Message file to stream response.
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@ -175,7 +182,7 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.
@ -183,7 +190,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

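Both file lookups above load the row inside a short-lived `Session(db.engine, expire_on_commit=False)` and then use the object after the block exits. A sketch of why that flag matters, with a hypothetical model standing in for MessageFile:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MessageFileDemo(Base):  # hypothetical stand-in for MessageFile
    __tablename__ = "message_files_demo"
    id: Mapped[int] = mapped_column(primary_key=True)
    url: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as seed:
    seed.add(MessageFileDemo(url="/files/tools/abc.png"))
    seed.commit()

with Session(engine, expire_on_commit=False) as session:
    message_file = session.scalar(select(MessageFileDemo))
    # under the default expire_on_commit=True, this commit would expire the
    # loaded attributes, and reading .url after the block detaches the object
    # would raise DetachedInstanceError; with the flag off it stays readable
    session.commit()

assert message_file is not None and message_file.url.endswith(".png")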
View File

@ -5,7 +5,6 @@ import queue
import re
import threading
from collections.abc import Iterable
from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
@ -56,7 +55,7 @@ def _process_future(
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
def __init__(self, tenant_id: str, voice: str, language: str | None = None):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
@ -72,8 +71,8 @@ class AppGeneratorTTSPublisher:
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event: Optional[AudioTrunk] = None
self.max_sentence = 2
self._last_audio_event: AudioTrunk | None = None
# FIXME better way to handle this threading.start
threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
@ -110,17 +109,19 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueNodeSucceededEvent):
if message.event.outputs is None:
continue
self.msg_text += message.event.outputs.get("output", "")
output = message.event.outputs.get("output", "")
if isinstance(output, str):
self.msg_text += output
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
if len(sentence_arr) >= min(self.max_sentence, 7):
self.max_sentence += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
if isinstance(text_tmp, str):
self.msg_text = text_tmp
else:
self.msg_text = ""

View File

@ -1,5 +1,5 @@
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union
from typing import Any, TextIO, Union
from pydantic import BaseModel
@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str:
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
def print_text(text: str, color: str | None = None, end: str = "", file: TextIO | None = None):
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
@ -34,10 +34,10 @@ def print_text(text: str, color: Optional[str] = None, end: str = "", file: Opti
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ""
color: str | None = ""
current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None:
def __init__(self, color: str | None = None):
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
) -> None:
):
"""Do nothing."""
if dify_config.DEBUG:
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@ -58,10 +58,10 @@ class DifyAgentCallbackHandler(BaseModel):
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
message_id: str | None = None,
timer: Any | None = None,
trace_manager: TraceQueueManager | None = None,
):
"""If not the final action, print out observation."""
if dify_config.DEBUG:
print_text("\n[on_tool_end]\n", color=self.color)
@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel):
)
)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
"""Do nothing."""
if dify_config.DEBUG:
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(self, thought: str) -> None:
def on_agent_start(self, thought: str):
"""Run on agent start."""
if dify_config.DEBUG:
if thought:
@ -98,13 +98,21 @@ class DifyAgentCallbackHandler(BaseModel):
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
def on_agent_finish(self, color: str | None = None, **kwargs: Any):
"""Run on agent end."""
if dify_config.DEBUG:
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
self.current_loop += 1
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
"""Run on datasource start."""
if dify_config.DEBUG:
print_text(
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
color=self.color,
)
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -19,14 +21,14 @@ class DatasetIndexToolCallbackHandler:
def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None:
):
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
self._user_id = user_id
self._invoke_from = invoke_from
def on_query(self, query: str, dataset_id: str) -> None:
def on_query(self, query: str, dataset_id: str):
"""
Handle query.
"""
@ -44,12 +46,13 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query)
db.session.commit()
def on_tool_end(self, documents: list[Document]) -> None:
def on_tool_end(self, documents: list[Document]):
"""Handle tool end."""
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,17 +60,14 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Iterable, Mapping
from typing import Any, Optional
from typing import Any
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
from core.ops.ops_trace_manager import TraceQueueManager
@ -14,9 +14,9 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Iterable[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
message_id: str | None = None,
timer: Any | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
for tool_output in tool_outputs:
print_text("\n[on_tool_execution]\n", color=self.color)

View File

@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
from configs import dify_config
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class DatasourcePlugin(ABC):
entity: DatasourceEntity
runtime: DatasourceRuntime
icon: str
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
icon: str,
) -> None:
self.entity = entity
self.runtime = runtime
self.icon = icon
@abstractmethod
def datasource_provider_type(self) -> str:
"""
returns the type of the datasource provider
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
icon=self.icon,
)
def get_icon_url(self, tenant_id: str) -> str:
return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}" # noqa: E501

View File

@ -0,0 +1,118 @@
from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_datasource_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@abstractmethod
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
"""
return datasource with given name
"""
pass
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema: dict[str, ProviderConfig] = {}
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@ -0,0 +1,40 @@
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
class DatasourceRuntime(BaseModel):
"""
Meta data of a datasource call processing
"""
tenant_id: str
datasource_id: str | None = None
invoke_from: InvokeFrom | None = None
datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeDatasourceRuntime(DatasourceRuntime):
"""
Fake datasource runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id",
invoke_from=InvokeFrom.DEBUGGER,
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
credentials={},
runtime_parameters={},
)

View File

View File

@ -0,0 +1,218 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class DatasourceFileManager:
@staticmethod
def sign_file(datasource_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature with a constant-time comparison to avoid timing leaks
if not hmac.compare_digest(sign, recalculated_encoded_sign):
    return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
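
A round-trip sketch of the sign/verify scheme above, with a hypothetical secret standing in for dify_config.SECRET_KEY:

import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"example-secret"  # hypothetical; stands in for dify_config.SECRET_KEY


def make_sign(file_id: str, timestamp: str, nonce: str) -> str:
    data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
    digest = hmac.new(SECRET_KEY, data_to_sign.encode(), hashlib.sha256).digest()
    return base64.urlsafe_b64encode(digest).decode()


file_id, timestamp, nonce = "f1", str(int(time.time())), os.urandom(16).hex()
token = make_sign(file_id, timestamp, nonce)

# the verifier recomputes the MAC from the same fields and compares digests;
# hmac.compare_digest keeps the comparison constant-time
assert hmac.compare_digest(token, make_sign(file_id, timestamp, nonce))
assert not hmac.compare_digest(token, make_sign("tampered", timestamp, nonce))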
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: str | None = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"datasources/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=present_filename,
size=len(file_binary),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(file_binary).hexdigest(),
source_url="",
created_at=datetime.now(),
)
db.session.add(upload_file)
db.session.commit()
db.session.refresh(upload_file)
return upload_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: str | None = None,
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
tenant_id=tenant_id,
user_id=user_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
if not upload_file:
return None
blob = storage.load_once(upload_file.key)
return blob, upload_file.mime_type
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
    tool_file_id = None
if tool_file_id is None:
    return None
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_upload_file_id(upload_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
return None, None
stream = storage.load_stream(upload_file.key)
return stream, upload_file.mime_type
# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager

View File

@ -0,0 +1,112 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
# check if context is set
try:
contexts.datasource_plugin_providers.get()
except LookupError:
contexts.datasource_plugin_providers.set({})
contexts.datasource_plugin_providers_lock.set(Lock())
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
controller: DatasourcePluginProviderController | None = None
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.ONLINE_DRIVE:
controller = OnlineDriveDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
if controller is None:
    raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
datasource_plugin_providers[provider_id] = controller
return controller
@classmethod
def get_datasource_runtime(
cls,
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
:param provider_id: the id of the provider
:param datasource_name: the name of the datasource
:param tenant_id: the tenant id
:param datasource_type: the type of the datasource provider
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)
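
get_datasource_plugin_provider caches controllers in a ContextVar guarded by a per-context Lock, so concurrent requests do not share mutable state. A minimal sketch of that request-scoped cache pattern, with hypothetical names:

from contextvars import ContextVar
from threading import Lock

_providers: ContextVar[dict[str, object]] = ContextVar("providers")
_providers_lock: ContextVar[Lock] = ContextVar("providers_lock")


def get_provider(provider_id: str) -> object:
    # initialize the per-context state on first use, mirroring the
    # LookupError handling in get_datasource_plugin_provider above
    try:
        _providers.get()
    except LookupError:
        _providers.set({})
        _providers_lock.set(Lock())
    with _providers_lock.get():
        cache = _providers.get()
        if provider_id not in cache:
            cache[provider_id] = object()  # stands in for building the controller
        return cache[provider_id]


assert get_provider("p1") is get_provider("p1")  # cached within this context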

View File

@ -0,0 +1,71 @@
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
class DatasourceApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: dict | None = None
original_credentials: dict | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("datasources", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"datasources": datasources,
"labels": self.labels,
}

View File

@ -0,0 +1,21 @@
from pydantic import BaseModel, Field
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
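
A short usage note for the class above: locales left unset fall back to en_US in __init__, so to_dict never emits None for a known locale.

# using the I18nObject defined above
obj = I18nObject(en_US="File Upload", ja_JP="ファイルアップロード")
assert obj.zh_Hans == "File Upload"  # unset locales fall back to en_US
assert obj.ja_JP == "ファイルアップロード"  # explicit values are kept
assert obj.to_dict()["pt_BR"] == "File Upload"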

View File

@ -0,0 +1,380 @@
import enum
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum):
"""
Enum class for datasource provider
"""
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
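
Since DatasourceProviderType is an enum.StrEnum, members also compare equal to their raw string values; value_of mainly adds the friendlier error message over the built-in DatasourceProviderType(value) lookup. For example:

# using the DatasourceProviderType defined above
assert DatasourceProviderType.value_of("website_crawl") is DatasourceProviderType.WEBSITE_CRAWL
assert DatasourceProviderType.WEBSITE_CRAWL == "website_crawl"  # StrEnum string equality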
class DatasourceParameter(PluginParameter):
"""
Overrides type
"""
class DatasourceParameterType(enum.StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: DatasourceParameterType = Field(..., description="The type of the parameter")
description: I18nObject = Field(..., description="The description of the parameter")
@classmethod
def get_simple_instance(
cls,
name: str,
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
:param name: the name of the parameter
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
type=typ,
required=required,
options=option_objs,
description=I18nObject(en_US="", zh_Hans=""),
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the datasource")
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: str | None = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: dict | None = None
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or []
class DatasourceProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: list[ToolLabelEnum] | None = Field(
default=[],
description="The tags of the tool",
)
def generate_datasource_icon_url(self, tenant_id: str) -> str:
HARD_CODED_DATASOURCE_ICONS = ["https://assets.dify.ai/images/File%20Upload.svg"]
if self.icon in HARD_CODED_DATASOURCE_ICONS:
return self.icon
return str(
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "plugin"
/ "icon"
% {"tenant_id": tenant_id, "filename": self.icon}
)
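
generate_datasource_icon_url builds the URL with yarl, where / appends percent-quoted path segments and % is shorthand for update_query(). A standalone sketch with a hypothetical base URL:

from yarl import URL

base = URL("https://console.example.com")  # stands in for dify_config.CONSOLE_API_URL
url = base / "console" / "api" / "workspaces" / "current" / "plugin" / "icon" % {
    "tenant_id": "t-1",
    "filename": "icon.svg",
}
assert str(url) == (
    "https://console.example.com/console/api/workspaces/current/plugin/icon"
    "?tenant_id=t-1&filename=icon.svg"
)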
class DatasourceProviderEntity(BaseModel):
"""
Datasource provider entity
"""
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: OAuthSchema | None = None
provider_type: DatasourceProviderType
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
"""
Get an instance of DatasourceInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class DatasourceLabel(BaseModel):
"""
Datasource label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
"""
Enum class for datasource invoke
"""
RAG_PIPELINE = "rag_pipeline"
class OnlineDocumentPage(BaseModel):
"""
Online document page
"""
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: dict | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: str | None = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
"""
Online document info
"""
workspace_id: str | None = Field(None, description="The workspace id")
workspace_name: str | None = Field(None, description="The workspace name")
workspace_icon: str | None = Field(None, description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class OnlineDocumentPagesMessage(BaseModel):
"""
Get online document pages response
"""
result: list[OnlineDocumentInfo]
class GetOnlineDocumentPageContentRequest(BaseModel):
"""
Get online document page content request
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
type: str = Field(..., description="The type of the page")
class OnlineDocumentPageContent(BaseModel):
"""
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
class GetOnlineDocumentPageContentResponse(BaseModel):
"""
Get online document page content response
"""
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
"""
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
source_url: str = Field(..., description="The url of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
class WebSiteInfo(BaseModel):
"""
Website info
"""
    status: str | None = Field(..., description="The crawl job status")
web_info_list: list[WebSiteInfoDetail] | None = []
total: int | None = Field(default=0, description="The total number of websites")
completed: int | None = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
"""
Get website crawl response
"""
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
class DatasourceMessage(ToolInvokeMessage):
pass
#########################
# Online drive file
#########################
class OnlineDriveFile(BaseModel):
"""
Online drive file
"""
id: str = Field(..., description="The file ID")
name: str = Field(..., description="The file name")
size: int = Field(..., description="The file size")
type: str = Field(..., description="The file type: folder or file")
class OnlineDriveFileBucket(BaseModel):
"""
Online drive file bucket
"""
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
"""
Get online drive file list request
"""
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):
"""
Get online drive file list response
"""
result: list[OnlineDriveFileBucket] = Field(..., description="The list of file buckets")
class OnlineDriveDownloadFileRequest(BaseModel):
"""
Get online drive file
"""
id: str = Field(..., description="The id of the file")
bucket: str | None = Field(None, description="The name of the bucket")
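Taken together, the browse request/response pair supports cursor-style pagination via `is_truncated` and `next_page_parameters`. A rough consumption loop, where `browse` stands in for whatever callable actually performs the plugin invocation:
request = OnlineDriveBrowseFilesRequest(prefix="", max_keys=20)
while True:
    response = browse(request)  # hypothetical invoker returning OnlineDriveBrowseFilesResponse
    bucket = response.result[0]
    for file in bucket.files:
        print(file.name, file.type, file.size)
    if not bucket.is_truncated:
        break
    # carry the opaque cursor forward into the next request
    request = OnlineDriveBrowseFilesRequest(
        bucket=bucket.bucket,
        prefix="",
        max_keys=20,
        next_page_parameters=bucket.next_page_parameters,
    )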

View File

@ -0,0 +1,37 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
class DatasourceProviderNotFoundError(ValueError):
pass
class DatasourceNotFoundError(ValueError):
pass
class DatasourceParameterValidationError(ValueError):
pass
class DatasourceProviderCredentialValidationError(ValueError):
pass
class DatasourceNotSupportedError(ValueError):
pass
class DatasourceInvokeError(ValueError):
pass
class DatasourceApiSchemaError(ValueError):
pass
class DatasourceEngineInvokeError(Exception):
meta: DatasourceInvokeMeta
    def __init__(self, meta: DatasourceInvokeMeta, **kwargs):
self.meta = meta
super().__init__(**kwargs)
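A sketch of how a caller might surface the structured metadata attached to an engine failure:
try:
    raise DatasourceEngineInvokeError(meta=DatasourceInvokeMeta.error_instance("timeout"))
except DatasourceEngineInvokeError as e:
    print(e.meta.to_dict())  # {'time_cost': 0.0, 'error': 'timeout', 'tool_config': {}}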

View File

@ -0,0 +1,29 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class LocalFileDatasourcePlugin(DatasourcePlugin):
tenant_id: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime, icon)
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE
def get_icon_url(self, tenant_id: str) -> str:
return self.icon

View File

@ -0,0 +1,56 @@
from typing import Any
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return LocalFileDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -0,0 +1,71 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
tenant_id: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime, icon)
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
def get_online_document_pages(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_pages(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def get_online_document_page_content(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_page_content(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT
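A hypothetical consumer of the streaming page listing; the provider controller and datasource name here are illustrative, not defined by this commit:
plugin = provider_controller.get_datasource("notion_datasource")  # illustrative names
for message in plugin.get_online_document_pages(
    user_id="user-1",
    datasource_parameters={},
    provider_type=plugin.datasource_provider_type(),
):
    for info in message.result:
        print(info.workspace_name, info.total, [p.page_name for p in info.pages])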

View File

@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -0,0 +1,71 @@
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
tenant_id: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime, icon)
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
def online_drive_browse_files(
self,
user_id: str,
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_browse_files(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def online_drive_download_file(
self,
user_id: str,
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_download_file(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DRIVE
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDriveDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

View File

@ -0,0 +1,127 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class DatasourceFileMessageTransformer:
@classmethod
def transform_datasource_invoke_messages(
cls,
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: str | None = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
"""
for message in messages:
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
yield message
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
message.message, DatasourceMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, DatasourceMessage.TextMessage)
tool_file_manager = ToolFileManager()
tool_file: ToolFile | None = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
if tool_file:
url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
# get filename from meta
filename = meta.get("file_name", None)
mimetype = meta.get("mime_type")
if not mimetype:
mimetype = (guess_type(filename)[0] if filename else None) or "application/octet-stream"
# if message is str, encode it to bytes
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
if blob_tool_file:
url = cls.get_datasource_file_url(
datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype)
)
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.FILE:
meta = message.meta or {}
file: File | None = meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
else:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@ -0,0 +1,51 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
WebsiteCrawlMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
tenant_id: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime, icon)
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
def get_website_crawl(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -0,0 +1,52 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,8 +1,8 @@
from enum import Enum
from enum import StrEnum, auto
class PlanningStrategy(Enum):
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"
class PlanningStrategy(StrEnum):
ROUTER = auto()
REACT_ROUTER = auto()
REACT = auto()
FUNCTION_CALL = auto()
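Note that `auto()` on a `StrEnum` produces the lower-cased member name, so these values are unchanged by the migration; members whose values are not the lower-cased name (such as the hyphenated ones elsewhere in this commit) must stay explicit. A standalone illustration:
from enum import StrEnum, auto

class Demo(StrEnum):
    REACT_ROUTER = auto()
    NO_CONFIGURE = "no-configure"  # cannot use auto(): value differs from the member name

assert Demo.REACT_ROUTER == "react_router"
assert Demo.NO_CONFIGURE == "no-configure"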

View File

@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum, auto
class EmbeddingInputType(Enum):
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = "document"
QUERY = "query"
DOCUMENT = auto()
QUERY = auto()

View File

@ -1,11 +1,9 @@
from typing import Optional
from pydantic import BaseModel
class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None
child_chunks: list[str] | None = None
class QAPreviewDetail(BaseModel):
@ -16,4 +14,28 @@ class QAPreviewDetail(BaseModel):
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
qa_preview: list[QAPreviewDetail] | None = None
class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str
class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
name: str
indexing_status: str
error: str | None = None
enabled: bool
class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict
@ -9,16 +8,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ProviderEntity
class ModelStatus(Enum):
class ModelStatus(StrEnum):
"""
Enum class for model status.
"""
ACTIVE = "active"
ACTIVE = auto()
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
DISABLED = auto()
CREDENTIAL_REMOVED = "credential-removed"
@ -29,11 +28,11 @@ class SimpleModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
def __init__(self, provider_entity: ProviderEntity):
"""
Init simple provider.
@ -57,7 +56,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None:
def raise_for_status(self):
"""
Check model status and raise ValueError if not active.
@ -92,8 +91,8 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@ -1,20 +1,20 @@
from enum import StrEnum
from enum import StrEnum, auto
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SELECT = auto()
STRING = auto()
NUMBER = auto()
FILE = auto()
FILES = auto()
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
BOOLEAN = auto()
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
ANY = "any"
ANY = auto()
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
ALL = auto()
CHAT = auto()
WORKFLOW = auto()
COMPLETION = auto()
class ModelSelectorScope(StrEnum):
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
RERANK = auto()
TTS = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
VISION = auto()
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"
ALL = auto()
CUSTOM = auto()
BUILTIN = auto()
WORKFLOW = auto()

View File

@ -1,9 +1,9 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
@ -41,6 +40,8 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
@ -90,7 +91,7 @@ class ProviderConfiguration(BaseModel):
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
Get current credentials.
@ -128,18 +129,42 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
credentials = None
current_credential_id = None
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
current_credential_id = model_configuration.current_credential_id
break
if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
current_credential_id = self.custom_configuration.provider.current_credential_id
if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
else:
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
return credentials
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
"""
Get system configuration status.
:return:
@ -180,16 +205,10 @@ class ProviderConfiguration(BaseModel):
"""
Get custom provider record.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
Provider.provider_name.in_(self._get_provider_names()),
)
return session.execute(stmt).scalar_one_or_none()
@ -251,7 +270,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name,
)
if exclude_id:
@ -265,7 +284,6 @@ class ProviderConfiguration(BaseModel):
:param credential_id: if provided, return the specified credential
:return:
"""
if credential_id:
return self._get_specific_provider_credential(credential_id)
@ -279,9 +297,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict, credential_id: str = "", session: Session | None = None
) -> dict:
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
"""
Validate custom credentials.
:param credentials: provider credentials
@ -290,7 +306,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
@ -302,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
@ -343,7 +359,75 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
        Generate a unique credential name for the provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
        Generate a unique credential name for a custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
    def _get_provider_names(self) -> list[str]:
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
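Per the docstring above, both spellings must be matched when filtering. Assuming the normalization implied by the code (a langgenius-qualified id exposes its bare provider name):
# ModelProviderID("langgenius/openai/openai").provider_name == "openai"
# so for a langgenius plugin, _get_provider_names() yields both spellings:
#   ["langgenius/openai/openai", "openai"]
# and queries filter with Provider.provider_name.in_(...) to match rows stored either way.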
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +435,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,8 +482,8 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
) -> None:
credential_name: str | None,
):
"""
        Update a saved provider credential (by credential_id).
@ -406,7 +493,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -418,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -428,9 +515,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@ -457,7 +544,7 @@ class ProviderConfiguration(BaseModel):
credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str,
session: Session,
) -> None:
):
"""
Update load balancing configurations that reference the given credential_id.
@ -471,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source,
)
@ -497,7 +584,7 @@ class ProviderConfiguration(BaseModel):
session.commit()
def delete_provider_credential(self, credential_id: str) -> None:
def delete_provider_credential(self, credential_id: str):
"""
Delete a saved provider credential (by credential_id).
@ -508,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -519,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
@ -532,13 +619,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -547,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@ -580,7 +661,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def switch_active_provider_credential(self, credential_id: str) -> None:
def switch_active_provider_credential(self, credential_id: str):
"""
Switch active provider credential (copy the selected one into current active snapshot).
@ -591,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -627,6 +708,7 @@ class ProviderConfiguration(BaseModel):
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -659,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -684,6 +766,7 @@ class ProviderConfiguration(BaseModel):
current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name
credentials = self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@ -705,7 +788,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
@ -714,9 +797,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> Optional[dict]:
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
"""
Get custom model credentials.
@ -738,6 +819,7 @@ class ProviderConfiguration(BaseModel):
):
current_credential_id = model_configuration.current_credential_id
current_credential_name = model_configuration.current_credential_name
credentials = self.obfuscated_credentials(
credentials=model_configuration.credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@ -758,7 +840,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict,
credential_id: str = "",
session: Session | None = None,
) -> dict:
):
"""
Validate custom model credentials.
@ -769,7 +851,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
@ -782,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -822,7 +904,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -833,10 +915,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -880,7 +967,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -893,7 +980,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -914,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -925,8 +1012,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@ -947,7 +1035,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
Delete a saved provider credential (by credential_id).
@ -958,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -968,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
@ -982,12 +1070,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -996,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1022,7 +1105,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
"""
        If the model list already contains this custom model, switch its credential.
        If it does not, use the credential to add a new custom model record.
@ -1036,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1054,6 +1137,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1064,7 +1148,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
        Switch the custom model credential.
@ -1077,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1094,7 +1178,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.
:param model_type: model type
@ -1124,14 +1208,9 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@ -1190,7 +1269,7 @@ class ProviderConfiguration(BaseModel):
return model_setting
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
"""
Get provider model setting.
:param model_type: model type
@ -1207,6 +1286,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1289,7 +1369,7 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
"""
Switch preferred provider type.
:param provider_type:
@ -1301,16 +1381,10 @@ class ProviderConfiguration(BaseModel):
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return
def _switch(s: Session) -> None:
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
def _switch(s: Session):
stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
)
preferred_model_provider = s.execute(stmt).scalars().first()
@ -1340,12 +1414,12 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.
@ -1366,7 +1440,7 @@ class ProviderConfiguration(BaseModel):
def get_provider_model(
self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
) -> ModelWithProviderEntity | None:
"""
Get provider model.
:param model_type: model type
@ -1383,7 +1457,7 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
@ -1567,7 +1641,7 @@ class ProviderConfiguration(BaseModel):
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
model: Optional[str] = None,
model: str | None = None,
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -1605,11 +1679,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@ -1631,6 +1703,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1663,11 +1737,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED
@ -1703,7 +1775,7 @@ class ProviderConfigurations(BaseModel):
super().__init__(tenant_id=tenant_id)
def get_models(
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get available models.
@ -1760,8 +1832,14 @@ class ProviderConfigurations(BaseModel):
def __setitem__(self, key, value):
self.configurations[key] = value
def __contains__(self, key):
if "/" not in key:
key = str(ModelProviderID(key))
return key in self.configurations
def __iter__(self):
return iter(self.configurations)
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())
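A hypothetical use of the mapping-style helpers, showing the key normalization in `__contains__` and the `(key, value)` pairs yielded by `__iter__`; `some_configuration` is a placeholder value:
configs = ProviderConfigurations(tenant_id="tenant-123")
configs["langgenius/openai/openai"] = some_configuration  # placeholder ProviderConfiguration
assert "openai" in configs  # bare keys are normalized through ModelProviderID
for key, configuration in configs:  # yields (key, value) tuples, unlike BaseModel's default
    print(key, type(configuration).__name__)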

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional, Union
from enum import StrEnum, auto
from typing import Union
from pydantic import BaseModel, ConfigDict, Field
@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
@ -31,25 +31,25 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class QuotaUnit(StrEnum):
TIMES = auto()
TOKENS = auto()
CREDITS = auto()
class SystemConfigurationStatus(Enum):
class SystemConfigurationStatus(StrEnum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
ACTIVE = auto()
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
UNSUPPORTED = auto()
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
base_model_name: str | None = None
model_type: ModelType
# pydantic configs
@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: bool | None = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for a provider model that has not yet been added to the model list.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@ -134,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
name: str
credentials: dict
credential_source_type: str | None = None
credential_id: str | None = None
class ModelSettings(BaseModel):
@ -144,6 +156,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs
@ -155,14 +168,14 @@ class BasicProviderConfig(BaseModel):
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
class Type(StrEnum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT
TEXT_INPUT = CommonParameterType.TEXT_INPUT
SELECT = CommonParameterType.SELECT
BOOLEAN = CommonParameterType.BOOLEAN
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
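Dropping the .value accessor relies on CommonParameterType being a string-valued enum itself, so each member is already a str and can seed another StrEnum directly. A minimal sketch under that assumption (values illustrative, not the real definition):

from enum import StrEnum

class CommonParameterType(StrEnum):  # assumed shape only
    SECRET_INPUT = "secret-input"
    TEXT_INPUT = "text-input"

class Type(StrEnum):
    SECRET_INPUT = CommonParameterType.SECRET_INPUT  # a StrEnum member is a str
    TEXT_INPUT = CommonParameterType.TEXT_INPUT

assert Type.SECRET_INPUT == "secret-input"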
@ -192,13 +205,13 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str, float, bool, list]] = None
options: Optional[list[Option]] = None
default: Union[int, str, float, bool] | None = None
options: list[Option] | None = None
multiple: bool | None = False
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
label: I18nObject | None = None
help: I18nObject | None = None
url: str | None = None
placeholder: I18nObject | None = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)

View File

@ -1,12 +1,9 @@
from typing import Optional
class LLMError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
description: str | None = None
def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: str | None = None):
self.description = description

View File

@ -10,11 +10,11 @@ class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
def __init__(self, api_endpoint: str, api_key: str):
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
def request(self, point: APIBasedExtensionPoint, params: dict):
"""
Request the API.
@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
timeout=self.timeout,
proxies=proxies,
)
except requests.exceptions.Timeout:
except requests.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:
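requests.Timeout and requests.ConnectionError are top-level re-exports of the same classes in requests.exceptions, so the shorter spelling is behavior-preserving:

import requests

assert requests.Timeout is requests.exceptions.Timeout
assert requests.ConnectionError is requests.exceptions.ConnectionError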

View File

@ -1,10 +1,10 @@
import enum
import importlib.util
import json
import logging
import os
from enum import StrEnum, auto
from pathlib import Path
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel
@ -13,18 +13,18 @@ from core.helper.position_helper import sort_to_dict_by_position_map
logger = logging.getLogger(__name__)
class ExtensionModule(enum.Enum):
MODERATION = "moderation"
EXTERNAL_DATA_TOOL = "external_data_tool"
class ExtensionModule(StrEnum):
MODERATION = auto()
EXTERNAL_DATA_TOOL = auto()
class ModuleExtension(BaseModel):
extension_class: Optional[Any] = None
extension_class: Any | None = None
name: str
label: Optional[dict] = None
form_schema: Optional[list] = None
label: dict | None = None
form_schema: list | None = None
builtin: bool = True
position: Optional[int] = None
position: int | None = None
class Extensible:
@ -32,9 +32,9 @@ class Extensible:
name: str
tenant_id: str
config: Optional[dict] = None
config: dict | None = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
def __init__(self, tenant_id: str, config: dict | None = None):
self.tenant_id = tenant_id
self.config = config
@ -91,7 +91,7 @@ class Extensible:
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
for obj in vars(mod).values():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
@ -123,7 +123,7 @@ class Extensible:
)
)
except Exception as e:
except Exception:
logger.exception("Error scanning extensions")
raise

View File

@ -41,9 +41,3 @@ class Extension:
assert module_extension.extension_class is not None
t: type = module_extension.extension_class
return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

@ -1,4 +1,4 @@
from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
@ -16,7 +16,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.
@ -28,18 +28,16 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: Optional[str] = None) -> str:
def query(self, inputs: dict, query: str | None = None) -> str:
"""
Query the external data tool.
@ -52,13 +50,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(

Some files were not shown because too many files have changed in this diff Show More