mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge branch 'main' into feat/mcp
This commit is contained in:
@ -3,6 +3,8 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
@ -18,6 +20,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
@ -28,7 +31,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models.model import Conversation
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AgentNode(ToolNode):
|
||||
@ -326,15 +328,12 @@ class AgentNode(ToolNode):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
@ -362,8 +361,10 @@ class AgentNode(ToolNode):
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features:
|
||||
if feature.value not in AgentOldVersionModelFeatures:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -26,7 +26,7 @@ class ParamsAutoGenerated(Enum):
|
||||
OPEN = 1
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(Enum):
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
"""
|
||||
Enum class for old SDK version llm feature.
|
||||
"""
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, cast
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import (
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnswerNode(BaseNode[AnswerNodeData]):
|
||||
|
||||
@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import BaseNodeData
|
||||
|
||||
|
||||
@ -8,10 +8,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -167,8 +167,11 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, dict) and all(
|
||||
value is None or isinstance(value, dict) for value in output_value
|
||||
elif (
|
||||
isinstance(first_element, dict)
|
||||
and all(value is None or isinstance(value, dict) for value in output_value)
|
||||
or isinstance(first_element, list)
|
||||
and all(value is None or isinstance(value, list) for value in output_value)
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
if value is not None:
|
||||
|
||||
@ -26,9 +26,9 @@ from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import FileSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode[EndNodeData]):
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
|
||||
@ -235,6 +235,10 @@ class Executor:
|
||||
files[key].append(file_tuple)
|
||||
|
||||
# convert files to list for httpx request
|
||||
# If there are no actual files, we still need to force httpx to use `multipart/form-data`.
|
||||
# This is achieved by inserting a harmless placeholder file that will be ignored by the server.
|
||||
if not files:
|
||||
self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))]
|
||||
if files:
|
||||
self.files = []
|
||||
for key, file_tuples in files.items():
|
||||
@ -373,7 +377,10 @@ class Executor:
|
||||
raw += f"{k}: {v}\r\n"
|
||||
|
||||
body_string = ""
|
||||
if self.files:
|
||||
# Only log actual files if present.
|
||||
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
|
||||
# This prevents logging meaningless placeholder entries.
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for key, (filename, content, mime_type) in self.files:
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
|
||||
@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories import file_factory
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeData,
|
||||
|
||||
@ -4,12 +4,12 @@ from typing_extensions import deprecated
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IfElseNode(BaseNode[IfElseNodeData]):
|
||||
|
||||
@ -12,10 +12,10 @@ from flask import Flask, current_app, has_request_context
|
||||
from configs import dify_config
|
||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
@ -37,7 +37,6 @@ from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@ -249,8 +248,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": outputs},
|
||||
metadata={
|
||||
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
},
|
||||
)
|
||||
)
|
||||
@ -361,16 +360,16 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
|
||||
iter_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
if parallel_mode_run_id:
|
||||
# for parallel, the specific branch ID is more important than the sequential index
|
||||
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
|
||||
if event.route_node_state.node_run_result:
|
||||
current_metadata = event.route_node_state.node_run_result.metadata or {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
|
||||
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
return event
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationStartNode(BaseNode[IterationStartNodeData]):
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
@ -41,7 +43,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData, ModelConfig
|
||||
@ -85,30 +86,31 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
||||
)
|
||||
# TODO(-LAN-): Move this check outside.
|
||||
# check rate limit
|
||||
if self.tenant_id:
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{self.tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{self.tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with Session(db.engine) as session:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
db.session.add(rate_limit_log)
|
||||
db.session.commit()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
||||
error_type="RateLimitExceeded",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
session.commit()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
||||
error_type="RateLimitExceeded",
|
||||
)
|
||||
|
||||
# retrieve knowledge
|
||||
try:
|
||||
@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
|
||||
model_instance, model_config = self.get_model_config(metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
|
||||
"""
|
||||
Fetch model config
|
||||
:param model: model
|
||||
:return:
|
||||
"""
|
||||
if model is None:
|
||||
raise ValueError("model is required")
|
||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
|
||||
@ -4,9 +4,9 @@ from typing import Any, Literal, Union
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: dict | None = None
|
||||
structured_output_enabled: bool = False
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
|
||||
@ -7,12 +7,12 @@ from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
@ -53,9 +54,10 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -70,14 +72,12 @@ from core.workflow.nodes.event import (
|
||||
from core.workflow.utils.structured_output.entities import (
|
||||
ResponseFormat,
|
||||
SpecialModelType,
|
||||
SupportStructuredOutputStatus,
|
||||
)
|
||||
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
@ -267,14 +267,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
except LLMNodeError as e:
|
||||
except ValueError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -302,8 +302,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
@ -474,7 +472,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||
elif isinstance(context_value_variable, ArraySegment):
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||
for item in context_value_variable.value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + "\n"
|
||||
@ -492,7 +490,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||
)
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict):
|
||||
if (
|
||||
"metadata" in context_dict
|
||||
and "_source" in context_dict["metadata"]
|
||||
@ -500,24 +498,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
):
|
||||
metadata = context_dict.get("metadata", {})
|
||||
|
||||
source = {
|
||||
"position": metadata.get("position"),
|
||||
"dataset_id": metadata.get("dataset_id"),
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
"hit_count": metadata.get("segment_hit_count"),
|
||||
"word_count": metadata.get("segment_word_count"),
|
||||
"segment_position": metadata.get("segment_position"),
|
||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
||||
"content": context_dict.get("content"),
|
||||
"page": metadata.get("page"),
|
||||
"doc_metadata": metadata.get("doc_metadata"),
|
||||
}
|
||||
source = RetrievalSourceMetadata(
|
||||
position=metadata.get("position"),
|
||||
dataset_id=metadata.get("dataset_id"),
|
||||
dataset_name=metadata.get("dataset_name"),
|
||||
document_id=metadata.get("document_id"),
|
||||
document_name=metadata.get("document_name"),
|
||||
data_source_type=metadata.get("data_source_type"),
|
||||
segment_id=metadata.get("segment_id"),
|
||||
retriever_from=metadata.get("retriever_from"),
|
||||
score=metadata.get("score"),
|
||||
hit_count=metadata.get("segment_hit_count"),
|
||||
word_count=metadata.get("segment_word_count"),
|
||||
segment_position=metadata.get("segment_position"),
|
||||
index_node_hash=metadata.get("segment_index_node_hash"),
|
||||
content=context_dict.get("content"),
|
||||
page=metadata.get("page"),
|
||||
doc_metadata=metadata.get("doc_metadata"),
|
||||
)
|
||||
|
||||
return source
|
||||
|
||||
@ -526,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = node_data_model.name
|
||||
provider_name = node_data_model.provider
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
# model config
|
||||
completion_params = node_data_model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if model_schema.support_structure_output:
|
||||
node_data_model.completion_params = self._handle_native_json_schema(
|
||||
node_data_model.completion_params, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
|
||||
|
||||
return model, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=model.provider_model_bundle,
|
||||
credentials=model.credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
@ -602,15 +588,11 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
@ -789,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
stop = model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=self.node_data.model.provider,
|
||||
model=self.node_data.model.name,
|
||||
)
|
||||
model_schema = model.model_type_instance.get_model_schema(
|
||||
model=self.node_data.model.name,
|
||||
credentials=model.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not model_schema.support_structure_output:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
@ -846,20 +840,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@ -902,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
@ -1184,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
|
||||
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
|
||||
"""
|
||||
Check if the current model supports structured output.
|
||||
|
||||
Returns:
|
||||
SupportStructuredOutput: The support status of structured output
|
||||
"""
|
||||
# Early return if structured output is disabled
|
||||
if (
|
||||
not isinstance(self.node_data, LLMNodeData)
|
||||
or not self.node_data.structured_output_enabled
|
||||
or not self.node_data.structured_output
|
||||
):
|
||||
return SupportStructuredOutputStatus.DISABLED
|
||||
# Get model schema and check if it exists
|
||||
model_schema = self._fetch_model_schema(self.node_data.model.provider)
|
||||
if not model_schema:
|
||||
return SupportStructuredOutputStatus.DISABLED
|
||||
|
||||
# Check if model supports structured output feature
|
||||
return (
|
||||
SupportStructuredOutputStatus.SUPPORTED
|
||||
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
|
||||
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||
)
|
||||
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
self,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopEndNode(BaseNode[LoopEndNodeData]):
|
||||
|
||||
@ -15,7 +15,8 @@ from core.variables import (
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
@ -198,9 +198,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
)
|
||||
@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
),
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
@ -331,7 +333,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
|
||||
graph_engine.graph_runtime_state.total_tokens
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
inputs=inputs,
|
||||
steps=current_index,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"check_break_result": True}
|
||||
@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if NodeRunMetadataKey.LOOP_ID not in metadata:
|
||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.LOOP_ID: self.node_id,
|
||||
NodeRunMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopStartNode(BaseNode[LoopStartNodeData]):
|
||||
|
||||
@ -25,13 +25,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import LLMNode, ModelConfig
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
@ -244,9 +243,9 @@ class ParameterExtractorNode(LLMNode):
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
@ -816,7 +813,6 @@ class ParameterExtractorNode(LLMNode):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# FIXME: fix the type error later
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
if node_data.instruction:
|
||||
|
||||
@ -10,7 +10,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import ModelInvokeCompletedEvent
|
||||
from core.workflow.nodes.llm import (
|
||||
@ -20,7 +21,6 @@ from core.workflow.nodes.llm import (
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .exc import InvalidModelTypeError
|
||||
@ -79,9 +79,13 @@ class QuestionClassifierNode(LLMNode):
|
||||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
# Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...).
|
||||
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
||||
# two consecutive user prompts will be generated, causing model's error.
|
||||
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
sys_files=files,
|
||||
@ -142,9 +146,9 @@ class QuestionClassifierNode(LLMNode):
|
||||
outputs=outputs,
|
||||
edge_source_handle=category_id,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
@ -154,9 +158,9 @@ class QuestionClassifierNode(LLMNode):
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class StartNode(BaseNode[StartNodeData]):
|
||||
|
||||
@ -4,10 +4,10 @@ from typing import Any, Optional
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
|
||||
|
||||
|
||||
@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to transform tool message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in NodeRunMetadataKey.__members__.values()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel):
|
||||
Group.
|
||||
"""
|
||||
|
||||
output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
output_type: SegmentType
|
||||
variables: list[list[str]]
|
||||
group_name: str
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
|
||||
Reference in New Issue
Block a user