mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
class NodeState(StrEnum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
@ -24,7 +24,7 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
|
||||
@ -252,7 +252,10 @@ class AgentNode(Node):
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
@ -266,7 +269,7 @@ class AgentNode(Node):
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
@ -417,7 +420,7 @@ class AgentNode(Node):
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
@ -476,7 +479,7 @@ class AgentNode(Node):
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
|
||||
@ -75,11 +75,11 @@ class DatasourceNode(Node):
|
||||
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
@ -267,7 +267,7 @@ class DatasourceNode(Node):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
|
||||
@ -234,7 +234,7 @@ class HttpRequestNode(Node):
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file.id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE.value,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
||||
@ -95,7 +95,7 @@ class IterationNode(Node):
|
||||
"config": {
|
||||
"is_parallel": False,
|
||||
"parallel_nums": 10,
|
||||
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
|
||||
"error_handle_mode": ErrorHandleMode.TERMINATED,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
||||
@ -27,7 +27,7 @@ from .exc import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
@ -77,7 +77,7 @@ class KnowledgeIndexNode(Node):
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
|
||||
else:
|
||||
is_preview = False
|
||||
chunks = variable.value
|
||||
|
||||
@ -72,7 +72,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
|
||||
@ -92,7 +92,7 @@ def fetch_memory(
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
@ -143,7 +143,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
|
||||
@ -945,7 +945,7 @@ class LLMNode(Node):
|
||||
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
|
||||
|
||||
if typed_node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
||||
|
||||
if typed_node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
@ -224,7 +224,7 @@ class ToolNode(Node):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
|
||||
@ -227,7 +227,7 @@ class WorkflowEntry:
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.START.value,
|
||||
"type": NodeType.START,
|
||||
"title": "Start",
|
||||
"desc": "Start",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user