import re from collections.abc import Mapping, Sequence from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from core.agent.entities import AgentLog, AgentResult from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType from dify_graph.entities import ToolCall, ToolCallResult from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.file import File from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import AgentLogEvent from dify_graph.nodes.base.entities import VariableSelector class ModelConfig(BaseModel): provider: str name: str mode: LLMMode completion_params: dict[str, Any] = Field(default_factory=dict) class ContextConfig(BaseModel): enabled: bool variable_selector: list[str] | None = None class VisionConfigOptions(BaseModel): variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH class VisionConfig(BaseModel): enabled: bool = False configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) @field_validator("configs", mode="before") @classmethod def convert_none_configs(cls, v: Any): if v is None: return VisionConfigOptions() return v class PromptConfig(BaseModel): jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) @field_validator("jinja2_variables", mode="before") @classmethod def convert_none_jinja2_variables(cls, v: Any): if v is None: return [] return v class LLMNodeChatModelMessage(ChatModelMessage): text: str = "" jinja2_text: str | None = None skill: bool = False metadata: Mapping[str, Any] | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): jinja2_text: str | None = None class ToolMetadata(BaseModel): """ Tool metadata for LLM node with tool support. Defines the essential fields needed for tool configuration, particularly the 'type' field to identify tool provider type. """ # Core fields enabled: bool = True type: ToolProviderType = Field( default=ToolProviderType.BUILT_IN, description="Tool provider type: builtin, api, mcp, workflow" ) provider_name: str = Field(..., description="Tool provider name/identifier") tool_name: str = Field(..., description="Tool name") # Optional fields plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools") credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication") # Configuration fields parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters") settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration") extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description") class ModelTraceSegment(BaseModel): """Model invocation trace segment with token usage and output.""" text: str | None = Field(None, description="Model output text content") reasoning: str | None = Field(None, description="Reasoning/thought content from model") tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model") @field_serializer("tool_calls") @classmethod def serialize_tool_calls(cls, tool_calls: list[ToolCall]) -> list[dict[str, Any]]: """Serialize tool_calls excluding icon fields.""" return [tc.model_dump(exclude={"icon", "icon_dark"}) for tc in tool_calls] class ToolTraceSegment(BaseModel): """Tool invocation trace segment with call details and result.""" id: str | None = Field(default=None, description="Unique identifier for this tool call") name: str | None = Field(default=None, description="Name of the tool being called") arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") output: str | None = Field(default=None, description="Tool call result") class LLMTraceSegment(BaseModel): """ Streaming trace segment for LLM tool-enabled runs. Represents alternating model and tool invocations in sequence: model -> tool -> model -> tool -> ... Each segment records its execution duration. """ type: Literal["model", "tool"] duration: float = Field(..., description="Execution duration in seconds") usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call") output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment") # Common metadata for both model and tool segments provider: str | None = Field(default=None, description="Model or tool provider identifier") name: str | None = Field(default=None, description="Name of the model or tool") icon: str | dict[str, Any] | None = Field(default=None, description="Icon for the provider") icon_dark: str | dict[str, Any] | None = Field(default=None, description="Dark theme icon for the provider") error: str | None = Field(default=None, description="Error message if segment failed") status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status") class LLMGenerationData(BaseModel): """Generation data from LLM invocation with tools. For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2 - reasoning_contents: [thought1, thought2, ...] - one element per turn - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results """ text: str = Field(..., description="Accumulated text content from all turns") reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn") tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results") sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering") usage: LLMUsage = Field(..., description="LLM usage statistics") finish_reason: str | None = Field(None, description="Finish reason from LLM") files: list[File] = Field(default_factory=list, description="Generated files") trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") class ThinkTagStreamParser: """Lightweight state machine to split streaming chunks by tags.""" _START_PATTERN = re.compile(r"]*)?>", re.IGNORECASE) _END_PATTERN = re.compile(r"", re.IGNORECASE) _START_PREFIX = " int: """Return length of the longest suffix of `text` that is a prefix of `prefix`.""" max_len = min(len(text), len(prefix) - 1) for i in range(max_len, 0, -1): if text[-i:].lower() == prefix[:i].lower(): return i return 0 def process(self, chunk: str) -> list[tuple[str, str]]: """ Split incoming chunk into ('thought' | 'text', content) tuples. Content excludes the tags themselves and handles split tags across chunks. """ parts: list[tuple[str, str]] = [] self._buffer += chunk while self._buffer: if self._in_think: end_match = self._END_PATTERN.search(self._buffer) if end_match: thought_text = self._buffer[: end_match.start()] if thought_text: parts.append(("thought", thought_text)) parts.append(("thought_end", "")) self._buffer = self._buffer[end_match.end() :] self._in_think = False continue hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX) emit = self._buffer[: len(self._buffer) - hold_len] if emit: parts.append(("thought", emit)) self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" break start_match = self._START_PATTERN.search(self._buffer) if start_match: prefix = self._buffer[: start_match.start()] if prefix: parts.append(("text", prefix)) self._buffer = self._buffer[start_match.end() :] parts.append(("thought_start", "")) self._in_think = True continue hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX) emit = self._buffer[: len(self._buffer) - hold_len] if emit: parts.append(("text", emit)) self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" break cleaned_parts: list[tuple[str, str]] = [] for kind, content in parts: # Extra safeguard: strip any stray tags that slipped through. content = self._START_PATTERN.sub("", content) content = self._END_PATTERN.sub("", content) if content or kind in {"thought_start", "thought_end"}: cleaned_parts.append((kind, content)) return cleaned_parts def flush(self) -> list[tuple[str, str]]: """Flush remaining buffer when the stream ends.""" if not self._buffer: return [] kind = "thought" if self._in_think else "text" content = self._buffer # Drop dangling partial tags instead of emitting them if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX): content = "" self._buffer = "" if not content and not self._in_think: return [] # Strip any complete tags that might still be present. content = self._START_PATTERN.sub("", content) content = self._END_PATTERN.sub("", content) result: list[tuple[str, str]] = [] if content: result.append((kind, content)) if self._in_think: result.append(("thought_end", "")) self._in_think = False return result class StreamBuffers(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser) pending_thought: list[str] = Field(default_factory=list) pending_content: list[str] = Field(default_factory=list) pending_tool_calls: list[ToolCall] = Field(default_factory=list) current_turn_reasoning: list[str] = Field(default_factory=list) reasoning_per_turn: list[str] = Field(default_factory=list) class TraceState(BaseModel): trace_segments: list[LLMTraceSegment] = Field(default_factory=list) tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict) tool_call_index_map: dict[str, int] = Field(default_factory=dict) model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment") model_start_emitted: bool = Field(default=False, description="Whether model_start has been emitted for this turn") pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment") class AggregatedResult(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) text: str = "" files: list[File] = Field(default_factory=list) usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) finish_reason: str | None = None class AgentContext(BaseModel): agent_logs: list[AgentLogEvent] = Field(default_factory=list) agent_result: AgentResult | None = None class ToolOutputState(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) stream: StreamBuffers = Field(default_factory=StreamBuffers) trace: TraceState = Field(default_factory=TraceState) aggregate: AggregatedResult = Field(default_factory=AggregatedResult) agent: AgentContext = Field(default_factory=AgentContext) class ToolLogPayload(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) tool_name: str = "" tool_call_id: str = "" tool_args: dict[str, Any] = Field(default_factory=dict) tool_output: Any = None tool_error: Any = None files: list[Any] = Field(default_factory=list) meta: dict[str, Any] = Field(default_factory=dict) @classmethod def from_log(cls, log: AgentLog) -> "ToolLogPayload": data = log.data or {} return cls( tool_name=data.get("tool_name", ""), tool_call_id=data.get("tool_call_id", ""), tool_args=data.get("tool_args") or {}, tool_output=data.get("output"), tool_error=data.get("error"), files=data.get("files") or [], meta=data.get("meta") or {}, ) @classmethod def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload": return cls( tool_name=data.get("tool_name", ""), tool_call_id=data.get("tool_call_id", ""), tool_args=data.get("tool_args") or {}, tool_output=data.get("output"), tool_error=data.get("error"), files=data.get("files") or [], meta=data.get("meta") or {}, ) class PromptMessageContext(BaseModel): """Context variable reference in prompt template. YAML/JSON format: { "$context": ["node_id", "variable_name"] } This will be expanded to list[PromptMessage] at runtime. """ model_config = ConfigDict(populate_by_name=True) value_selector: Sequence[str] = Field(alias="$context") # Union type for prompt template items (static message or context variable reference) PromptTemplateItem: TypeAlias = Annotated[ LLMNodeChatModelMessage | PromptMessageContext, Field(discriminator=None), ] class ToolSetting(BaseModel): model_config = ConfigDict(extra="forbid") type: ToolProviderType provider: str tool_name: str enabled: bool = Field(default=True, description="Whether the tool is enabled") class LLMNodeData(BaseNodeData): type: NodeType = BuiltinNodeTypes.LLM model: ModelConfig prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: Mapping[str, Any] | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") reasoning_format: Literal["separated", "tagged"] = Field( # Keep tagged as default for backward compatibility default="tagged", description=( """ Strategy for handling model reasoning output. separated: Return clean text (without tags) + reasoning_content field. Recommended for new workflows. Enables safe downstream parsing and workflow variable access: {{#node_id.reasoning_content#}} tagged : Return original text (with tags) + reasoning_content field. Maintains full backward compatibility while still providing reasoning_content for workflow automation. Frontend thinking panels work as before. """ ), ) # Computer Use computer_use: bool = Field(default=False, description="Whether to use the computer use feature") # Tool support tools: Sequence[ToolMetadata] = Field(default_factory=list) tool_settings: Sequence[ToolSetting] = Field(default_factory=list) max_iterations: int | None = Field(default=100, description="Maximum number of iterations for the LLM node") @field_validator("prompt_config", mode="before") @classmethod def convert_none_prompt_config(cls, v: Any): if v is None: return PromptConfig() return v @property def structured_output_enabled(self) -> bool: return self.structured_output_switch_on and self.structured_output is not None