from __future__ import annotations

import base64
import io
import json
import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.enums import (
    NodeType,
    SystemVariableKey,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkWithStructuredOutput,
    LLMResultWithStructuredOutput,
    LLMStructuredOutput,
    LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageRole,
    SystemPromptMessage,
    UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
    ModelInvokeCompletedEvent,
    NodeEventBase,
    NodeRunResult,
    RunRetrieverResourceEvent,
    StreamChunkEvent,
    StreamCompletedEvent,
)
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
    ArrayFileSegment,
    ArraySegment,
    FileSegment,
    NoneSegment,
    ObjectSegment,
    StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile

from . import llm_utils
from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
)
from .exc import (
    InvalidContextStructureError,
    InvalidVariableTypeError,
    LLMNodeError,
    MemoryRolePrefixRequiredError,
    NoPromptFoundError,
    TemplateTypeNotSupportError,
    VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver

if TYPE_CHECKING:
    from dify_graph.file.models import File
    from dify_graph.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)


class LLMNode(Node[LLMNodeData]):
    node_type = NodeType.LLM

    # Compiled regex for extracting <think> blocks (with compatibility for attributes)
    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

    # Instance attributes specific to LLMNode.
    # Output variable for file
    _file_outputs: list[File]

    _llm_file_saver: LLMFileSaver
    _credentials_provider: CredentialsProvider
    _model_factory: ModelFactory
    _model_instance: ModelInstance
    _memory: PromptMessageMemory | None

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: GraphInitParams,
        graph_runtime_state: GraphRuntimeState,
        *,
        credentials_provider: CredentialsProvider,
        model_factory: ModelFactory,
        model_instance: ModelInstance,
        memory: PromptMessageMemory | None = None,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance
        self._memory = memory

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
                user_id=graph_init_params.user_id,
                tenant_id=graph_init_params.tenant_id,
            )
        self._llm_file_saver = llm_file_saver

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator:
        node_inputs: dict[str, Any] = {}
        process_data: dict[str, Any] = {}
        result_text = ""
        clean_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
        reasoning_content = None
        variable_pool = self.graph_runtime_state.variable_pool

        try:
            # init messages template
            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)

            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data=self.node_data)

            # fetch jinja2 inputs
            jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)

            # merge inputs
            inputs.update(jinja_inputs)

            # fetch files
            files = (
                llm_utils.fetch_files(
                    variable_pool=variable_pool,
                    selector=self.node_data.vision.configs.variable_selector,
                )
                if self.node_data.vision.enabled
                else []
            )

            if files:
                node_inputs["#files#"] = [file.to_dict() for file in files]

            # fetch context value
            generator = self._fetch_context(node_data=self.node_data)
            context = None
            context_files: list[File] = []
            for event in generator:
                context = event.context
                context_files = event.context_files or []
                yield event

            if context:
                node_inputs["#context#"] = context
            if context_files:
                node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

            # fetch model config
            model_instance = self._model_instance
            model_name = model_instance.model_name
            model_provider = model_instance.provider
            model_stop = model_instance.stop
            memory = self._memory

            query: str | None = None
            if self.node_data.memory:
                query = self.node_data.memory.query_prompt_template
                if not query and (
                    query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                ):
                    query = query_variable.text

            prompt_messages, stop = LLMNode.fetch_prompt_messages(
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                model_instance=model_instance,
                stop=model_stop,
                prompt_template=self.node_data.prompt_template,
                memory_config=self.node_data.memory,
                vision_enabled=self.node_data.vision.enabled,
                vision_detail=self.node_data.vision.configs.detail,
                variable_pool=variable_pool,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                context_files=context_files,
            )

            # handle invoke result
            generator = LLMNode.invoke_llm(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=self.node_data.structured_output_enabled,
                structured_output=self.node_data.structured_output,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
                reasoning_format=self.node_data.reasoning_format,
            )

            structured_output: LLMStructuredOutput | None = None

            for event in generator:
                if isinstance(event, StreamChunkEvent):
                    yield event
                elif isinstance(event, ModelInvokeCompletedEvent):
                    # Raw text
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
                    reasoning_content = event.reasoning_content or ""

                    # For downstream nodes, determine clean text based on reasoning_format
                    if self.node_data.reasoning_format == "tagged":
                        # Keep <think> tags for backward compatibility
                        clean_text = result_text
                    else:
                        # Extract clean text from <think> tags
                        clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)

                    # Process structured output if available from the event.
                    structured_output = (
                        LLMStructuredOutput(structured_output=event.structured_output)
                        if event.structured_output
                        else None
                    )
                    break
                elif isinstance(event, LLMStructuredOutput):
                    structured_output = event

            process_data = {
                "model_mode": self.node_data.model.mode,
                "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
                ),
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
                "model_provider": model_provider,
                "model_name": model_name,
            }

            outputs = {
                "text": clean_text,
                "reasoning_content": reasoning_content,
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
            }
            if structured_output:
                outputs["structured_output"] = structured_output.structured_output
            if self._file_outputs:
                outputs["files"] = ArrayFileSegment(value=self._file_outputs)

            # Send final chunk event to indicate streaming is complete
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk="",
                is_final=True,
            )

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=node_inputs,
                    process_data=process_data,
                    outputs=outputs,
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                    },
                    llm_usage=usage,
                )
            )
        except ValueError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=str(e),
                    inputs=node_inputs,
                    process_data=process_data,
                    error_type=type(e).__name__,
                    llm_usage=usage,
                )
            )
        except Exception as e:
            logger.exception("error while executing llm node")
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=str(e),
                    inputs=node_inputs,
                    process_data=process_data,
                    error_type=type(e).__name__,
                    llm_usage=usage,
                )
            )

    @staticmethod
    def invoke_llm(
        *,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None = None,
        user_id: str,
        structured_output_enabled: bool,
        structured_output: Mapping[str, Any] | None = None,
        file_saver: LLMFileSaver,
        file_outputs: list[File],
        node_id: str,
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_parameters = model_instance.parameters
        invoke_model_parameters = dict(model_parameters)
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
                structured_output=structured_output or {},
            )
            request_start_time = time.perf_counter()
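            # Structured-output path: the JSON schema fetched above is handed to the
            # invoke_llm_with_structured_output helper; request_start_time is captured just
            # before the call so the measured latency covers only the model request.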
            invoke_result = invoke_llm_with_structured_output(
                provider=model_instance.provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=output_schema,
                model_parameters=invoke_model_parameters,
                stop=list(stop or []),
                stream=True,
                user=user_id,
            )
        else:
            request_start_time = time.perf_counter()
            invoke_result = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages),
                model_parameters=invoke_model_parameters,
                stop=list(stop or []),
                stream=True,
                user=user_id,
            )

        return LLMNode.handle_invoke_result(
            invoke_result=invoke_result,
            file_saver=file_saver,
            file_outputs=file_outputs,
            node_id=node_id,
            node_type=node_type,
            reasoning_format=reasoning_format,
            request_start_time=request_start_time,
        )

    @staticmethod
    def handle_invoke_result(
        *,
        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
        file_saver: LLMFileSaver,
        file_outputs: list[File],
        node_id: str,
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
        request_start_time: float | None = None,
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        # For blocking mode
        if isinstance(invoke_result, LLMResult):
            duration = None
            if request_start_time is not None:
                duration = time.perf_counter() - request_start_time
                invoke_result.usage.latency = round(duration, 3)
            event = LLMNode.handle_blocking_result(
                invoke_result=invoke_result,
                saver=file_saver,
                file_outputs=file_outputs,
                reasoning_format=reasoning_format,
                request_latency=duration,
            )
            yield event
            return

        # For streaming mode
        model = ""
        prompt_messages: list[PromptMessage] = []
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()

        # Initialize streaming metrics tracking
        start_time = request_start_time if request_start_time is not None else time.perf_counter()
        first_token_time = None
        has_content = False
        collected_structured_output = None  # Collect structured_output from streaming chunks

        # Consume the invoke result and handle generator exception
        try:
            for result in invoke_result:
                if isinstance(result, LLMResultChunkWithStructuredOutput):
                    # Collect structured_output from the chunk
                    if result.structured_output is not None:
                        collected_structured_output = dict(result.structured_output)
                    yield result
                if isinstance(result, LLMResultChunk):
                    contents = result.delta.message.content
                    for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
                        contents=contents,
                        file_saver=file_saver,
                        file_outputs=file_outputs,
                    ):
                        # Detect first token for TTFT calculation
                        if text_part and not has_content:
                            first_token_time = time.perf_counter()
                            has_content = True
                        full_text_buffer.write(text_part)
                        yield StreamChunkEvent(
                            selector=[node_id, "text"],
                            chunk=text_part,
                            is_final=False,
                        )

                    # Update the whole metadata
                    if not model and result.model:
                        model = result.model
                    if len(prompt_messages) == 0:
                        # TODO(QuantumGhost): it seems that this update has no visible effect.
                        # What's the purpose of the line below?
                        prompt_messages = list(result.prompt_messages)
                    if usage.prompt_tokens == 0 and result.delta.usage:
                        usage = result.delta.usage
                    if finish_reason is None and result.delta.finish_reason:
                        finish_reason = result.delta.finish_reason
        except OutputParserError as e:
            raise LLMNodeError(f"Failed to parse structured output: {e}")

        # Extract reasoning content from <think> tags in the main text
        full_text = full_text_buffer.getvalue()

        if reasoning_format == "tagged":
            # Keep <think> tags in text for backward compatibility
            clean_text = full_text
            reasoning_content = ""
        else:
            # Extract clean text and reasoning from <think> tags
            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)

        # Calculate streaming metrics
        end_time = time.perf_counter()
        total_duration = end_time - start_time
        usage.latency = round(total_duration, 3)

        if has_content and first_token_time:
            gen_ai_server_time_to_first_token = first_token_time - start_time
            llm_streaming_time_to_generate = end_time - first_token_time
            usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3)
            usage.time_to_generate = round(llm_streaming_time_to_generate, 3)

        yield ModelInvokeCompletedEvent(
            # Use clean_text for separated mode, full_text for tagged mode
            text=clean_text if reasoning_format == "separated" else full_text,
            usage=usage,
            finish_reason=finish_reason,
            # Reasoning content for workflow variables and downstream nodes
            reasoning_content=reasoning_content,
            # Pass structured output if collected from streaming chunks
            structured_output=collected_structured_output,
        )

    @staticmethod
    def _image_file_to_markdown(file: File, /):
        text_chunk = f"![]({file.generate_url()})"
        return text_chunk

    @classmethod
    def _split_reasoning(
        cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
    ) -> tuple[str, str]:
        """
        Split reasoning content from text based on reasoning_format strategy.

        Args:
            text: Full text that may contain <think> blocks
            reasoning_format: Strategy for handling reasoning content
                - "separated": Remove <think> tags and return clean text + reasoning_content field
                - "tagged": Keep <think> tags in text, return empty reasoning_content

        Returns:
            tuple of (clean_text, reasoning_content)
        """
        if reasoning_format == "tagged":
            return text, ""

        # Find all <think>...</think> blocks (case-insensitive)
        matches = cls._THINK_PATTERN.findall(text)

        # Extract reasoning content from all <think> blocks
        reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""

        # Remove all <think>...</think> blocks from original text
        clean_text = cls._THINK_PATTERN.sub("", text)

        # Clean up extra whitespace
        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()

        # Separated mode: always return clean text and reasoning_content
        return clean_text, reasoning_content or ""

    def _transform_chat_messages(
        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
            if messages.edition_type == "jinja2" and messages.jinja2_text:
                messages.text = messages.jinja2_text
            return messages

        for message in messages:
            if message.edition_type == "jinja2" and message.jinja2_text:
                message.text = message.jinja2_text

        return messages

    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
        variables: dict[str, Any] = {}

        if not node_data.prompt_config:
            return variables

        for variable_selector in node_data.prompt_config.jinja2_variables or []:
            variable_name = variable_selector.variable
            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            if variable is None:
                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")

            def parse_dict(input_dict: Mapping[str, Any]) -> str:
                """
                Parse dict into string
                """
                # check if it's a context structure
                if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
                    return str(input_dict["content"])

                # else, parse the dict
                try:
                    return json.dumps(input_dict, ensure_ascii=False)
                except Exception:
                    return str(input_dict)

            if isinstance(variable, ArraySegment):
                result = ""
                for item in variable.value:
                    if isinstance(item, dict):
                        result += parse_dict(item)
                    else:
                        result += str(item)
                    result += "\n"
                value = result.strip()
            elif isinstance(variable, ObjectSegment):
                value = parse_dict(variable.value)
            else:
                value = variable.text

            variables[variable_name] = value

        return variables

    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
        inputs = {}
        prompt_template = node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
            for prompt in prompt_template:
                variable_template_parser = VariableTemplateParser(template=prompt.text)
                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        elif isinstance(prompt_template, CompletionModelPromptTemplate):
            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
            variable_selectors = variable_template_parser.extract_variable_selectors()

        for variable_selector in variable_selectors:
            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            if variable is None:
                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
            if isinstance(variable, NoneSegment):
                inputs[variable_selector.variable] = ""
            inputs[variable_selector.variable] = variable.to_object()

        memory = node_data.memory
        if memory and memory.query_prompt_template:
            query_variable_selectors = VariableTemplateParser(
                template=memory.query_prompt_template
            ).extract_variable_selectors()
            for variable_selector in query_variable_selectors:
                variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
                if variable is None:
                    raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
                if isinstance(variable, NoneSegment):
                    continue
                inputs[variable_selector.variable] = variable.to_object()

        return inputs

    def _fetch_context(self, node_data: LLMNodeData):
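        # Yields RunRetrieverResourceEvent items carrying the assembled context string,
        # citation metadata, and any image attachments bound to the retrieved segments;
        # yields nothing when the context feature is disabled or no selector is set.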
        if not node_data.context.enabled:
            return

        if not node_data.context.variable_selector:
            return

        context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
        if context_value_variable:
            if isinstance(context_value_variable, StringSegment):
                yield RunRetrieverResourceEvent(
                    retriever_resources=[], context=context_value_variable.value, context_files=[]
                )
            elif isinstance(context_value_variable, ArraySegment):
                context_str = ""
                original_retriever_resource: list[RetrievalSourceMetadata] = []
                context_files: list[File] = []
                for item in context_value_variable.value:
                    if isinstance(item, str):
                        context_str += item + "\n"
                    else:
                        if "content" not in item:
                            raise InvalidContextStructureError(f"Invalid context structure: {item}")
                        if item.get("summary"):
                            context_str += item["summary"] + "\n"
                        context_str += item["content"] + "\n"

                        retriever_resource = self._convert_to_original_retriever_resource(item)
                        if retriever_resource:
                            original_retriever_resource.append(retriever_resource)

                            attachments_with_bindings = db.session.execute(
                                select(SegmentAttachmentBinding, UploadFile)
                                .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
                                .where(
                                    SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
                                )
                            ).all()
                            if attachments_with_bindings:
                                for _, upload_file in attachments_with_bindings:
                                    attachment_info = File(
                                        id=upload_file.id,
                                        filename=upload_file.name,
                                        extension="." + upload_file.extension,
                                        mime_type=upload_file.mime_type,
                                        tenant_id=self.tenant_id,
                                        type=FileType.IMAGE,
                                        transfer_method=FileTransferMethod.LOCAL_FILE,
                                        remote_url=upload_file.source_url,
                                        related_id=upload_file.id,
                                        size=upload_file.size,
                                        storage_key=upload_file.key,
                                        url=sign_upload_file(upload_file.id, upload_file.extension),
                                    )
                                    context_files.append(attachment_info)

                yield RunRetrieverResourceEvent(
                    retriever_resources=original_retriever_resource,
                    context=context_str.strip(),
                    context_files=context_files,
                )

    def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
        if (
            "metadata" in context_dict
            and "_source" in context_dict["metadata"]
            and context_dict["metadata"]["_source"] == "knowledge"
        ):
            metadata = context_dict.get("metadata", {})

            source = RetrievalSourceMetadata(
                position=metadata.get("position"),
                dataset_id=metadata.get("dataset_id"),
                dataset_name=metadata.get("dataset_name"),
                document_id=metadata.get("document_id"),
                document_name=metadata.get("document_name"),
                data_source_type=metadata.get("data_source_type"),
                segment_id=metadata.get("segment_id"),
                retriever_from=metadata.get("retriever_from"),
                score=metadata.get("score"),
                hit_count=metadata.get("segment_hit_count"),
                word_count=metadata.get("segment_word_count"),
                segment_position=metadata.get("segment_position"),
                index_node_hash=metadata.get("segment_index_node_hash"),
                content=context_dict.get("content"),
                page=metadata.get("page"),
                doc_metadata=metadata.get("doc_metadata"),
                files=context_dict.get("files"),
                summary=context_dict.get("summary"),
            )

            return source

        return None

    @staticmethod
    def fetch_prompt_messages(
        *,
        sys_query: str | None = None,
        sys_files: Sequence[File],
        context: str | None = None,
        memory: PromptMessageMemory | None = None,
        model_instance: ModelInstance,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        stop: Sequence[str] | None = None,
        memory_config: MemoryConfig | None = None,
        vision_enabled: bool = False,
        vision_detail: ImagePromptMessageContent.DETAIL,
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
        context_files: list[File] | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        prompt_messages: list[PromptMessage] = []
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(
                LLMNode.handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                )
            )

            # Get memory messages for chat mode
            memory_messages = _handle_memory_chat_mode(
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)

            # Add current query to the prompt messages
            if sys_query:
                message = LLMNodeChatModelMessage(
                    text=sys_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                prompt_messages.extend(
                    LLMNode.handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
                        variable_pool=variable_pool,
                        vision_detail_config=vision_detail,
                    )
                )

        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # For completion model
            prompt_messages.extend(
                _handle_completion_template(
                    template=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
            )

            # Get memory text for completion model
            memory_text = _handle_memory_completion_mode(
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
            )

            # Insert histories into the prompt
            prompt_content = prompt_messages[0].content
            # For issue #11247 - Check if prompt content is a string or a list
            if isinstance(prompt_content, str):
                prompt_content = str(prompt_content)
                if "#histories#" in prompt_content:
                    prompt_content = prompt_content.replace("#histories#", memory_text)
                else:
                    prompt_content = memory_text + "\n" + prompt_content
                prompt_messages[0].content = prompt_content
            elif isinstance(prompt_content, list):
                for content_item in prompt_content:
                    if isinstance(content_item, TextPromptMessageContent):
                        if "#histories#" in content_item.data:
                            content_item.data = content_item.data.replace("#histories#", memory_text)
                        else:
                            content_item.data = memory_text + "\n" + content_item.data
            else:
                raise ValueError("Invalid prompt content type")

            # Add current query to the prompt message
            if sys_query:
                if isinstance(prompt_content, str):
                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                    prompt_messages[0].content = prompt_content
                elif isinstance(prompt_content, list):
                    for content_item in prompt_content:
                        if isinstance(content_item, TextPromptMessageContent):
                            content_item.data = sys_query + "\n" + content_item.data
                else:
                    raise ValueError("Invalid prompt content type")
        else:
            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))

        # The sys_files will be deprecated later
        if vision_enabled and sys_files:
            file_prompts = []
            for file in sys_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            # If last prompt is a user prompt, add files into its contents,
            # otherwise append a new user prompt
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))

        # The context_files
        if vision_enabled and context_files:
            file_prompts = []
            for file in context_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            # If last prompt is a user prompt, add files into its contents,
            # otherwise append a new user prompt
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))

        # Remove empty messages and filter unsupported content
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, list):
                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                for content_item in prompt_message.content:
                    # Skip content if features are not defined
                    if not model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        prompt_message_content.append(content_item)
                        continue

                    # Skip content if corresponding feature is not supported
                    if (
                        (
                            content_item.type == PromptMessageContentType.IMAGE
                            and ModelFeature.VISION not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.DOCUMENT
                            and ModelFeature.DOCUMENT not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.VIDEO
                            and ModelFeature.VIDEO not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.AUDIO
                            and ModelFeature.AUDIO not in model_schema.features
                        )
                    ):
                        continue
                    prompt_message_content.append(content_item)

                if (
                    len(prompt_message_content) == 1
                    and prompt_message_content[0].type == PromptMessageContentType.TEXT
                ):
                    prompt_message.content = prompt_message_content[0].data
                else:
                    prompt_message.content = prompt_message_content

            if prompt_message.is_empty():
                continue
            filtered_prompt_messages.append(prompt_message)

        if len(filtered_prompt_messages) == 0:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )

        return filtered_prompt_messages, stop

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        # graph_config is not used in this node type
        _ = graph_config  # Explicitly mark as unused
        # Create typed NodeData from dict
        typed_node_data = LLMNodeData.model_validate(node_data)

        prompt_template = typed_node_data.prompt_template

        variable_selectors = []
        if isinstance(prompt_template, list):
            for prompt in prompt_template:
                if prompt.edition_type != "jinja2":
                    variable_template_parser = VariableTemplateParser(template=prompt.text)
                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            if prompt_template.edition_type != "jinja2":
                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                variable_selectors = variable_template_parser.extract_variable_selectors()
        else:
            raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")

        variable_mapping: dict[str, Any] = {}
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        memory = typed_node_data.memory
        if memory and memory.query_prompt_template:
            query_variable_selectors = VariableTemplateParser(
                template=memory.query_prompt_template
            ).extract_variable_selectors()
            for variable_selector in query_variable_selectors:
                variable_mapping[variable_selector.variable] = variable_selector.value_selector

        if typed_node_data.context.enabled:
            variable_mapping["#context#"] = typed_node_data.context.variable_selector

        if typed_node_data.vision.enabled:
            variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector

        if typed_node_data.memory:
            variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]

        if typed_node_data.prompt_config:
            enable_jinja = False

            if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
                if prompt_template.edition_type == "jinja2":
                    enable_jinja = True
            else:
                for prompt in prompt_template:
                    if prompt.edition_type == "jinja2":
                        enable_jinja = True
                        break

            if enable_jinja:
                for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
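        # Namespace every mapping key with this node's id (e.g. "<node_id>.<variable>")
        # so the returned selectors stay unambiguous across the graph.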
        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

        return variable_mapping

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {
            "type": "llm",
            "config": {
                "prompt_templates": {
                    "chat_model": {
                        "prompts": [
                            {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
                        ]
                    },
                    "completion_model": {
                        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
                        "prompt": {
                            "text": "Here are the chat histories between human and assistant, inside "
                            "<histories></histories> XML tags.\n\n<histories>\n{{"
                            "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
                            "edition_type": "basic",
                        },
                        "stop": ["Human:"],
                    },
                }
            },
        }

    @staticmethod
    def handle_list_messages(
        *,
        messages: Sequence[LLMNodeChatModelMessage],
        context: str | None,
        jinja2_variables: Sequence[VariableSelector],
        variable_pool: VariablePool,
        vision_detail_config: ImagePromptMessageContent.DETAIL,
    ) -> Sequence[PromptMessage]:
        prompt_messages: list[PromptMessage] = []
        for message in messages:
            if message.edition_type == "jinja2":
                result_text = _render_jinja2_message(
                    template=message.jinja2_text or "",
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
                prompt_message = _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
                )
                prompt_messages.append(prompt_message)
            else:
                # Get segment group from basic message
                if context:
                    template = message.text.replace("{#context#}", context)
                else:
                    template = message.text
                segment_group = variable_pool.convert_template(template)

                # Process segments for images
                file_contents = []
                for segment in segment_group.value:
                    if isinstance(segment, ArrayFileSegment):
                        for file in segment.value:
                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                                file_content = file_manager.to_prompt_message_content(
                                    file, image_detail_config=vision_detail_config
                                )
                                file_contents.append(file_content)
                    elif isinstance(segment, FileSegment):
                        file = segment.value
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)

                # Create message with text from all segments
                plain_text = segment_group.text
                if plain_text:
                    prompt_message = _combine_message_content_with_role(
                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                    )
                    prompt_messages.append(prompt_message)

                if file_contents:
                    # Create message with image contents
                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                    prompt_messages.append(prompt_message)

        return prompt_messages

    @staticmethod
    def handle_blocking_result(
        *,
        invoke_result: LLMResult | LLMResultWithStructuredOutput,
        saver: LLMFileSaver,
        file_outputs: list[File],
        reasoning_format: Literal["separated", "tagged"] = "tagged",
        request_latency: float | None = None,
    ) -> ModelInvokeCompletedEvent:
        buffer = io.StringIO()
        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
            contents=invoke_result.message.content,
            file_saver=saver,
            file_outputs=file_outputs,
        ):
            buffer.write(text_part)

        # Extract reasoning content from <think> tags in the main text
        full_text = buffer.getvalue()

        if reasoning_format == "tagged":
            # Keep <think> tags in text for backward compatibility
            clean_text = full_text
            reasoning_content = ""
        else:
            # Extract clean text and reasoning from <think> tags
            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
        event = ModelInvokeCompletedEvent(
            # Use clean_text for separated mode, full_text for tagged mode
            text=clean_text if reasoning_format == "separated" else full_text,
            usage=invoke_result.usage,
            finish_reason=None,
            # Reasoning content for workflow variables and downstream nodes
            reasoning_content=reasoning_content,
            # Pass structured output if enabled
            structured_output=getattr(invoke_result, "structured_output", None),
        )
        if request_latency is not None:
            event.usage.latency = round(request_latency, 3)
        return event

    @staticmethod
    def save_multimodal_image_output(
        *,
        content: ImagePromptMessageContent,
        file_saver: LLMFileSaver,
    ) -> File:
        """save_multimodal_image_output saves multi-modal contents generated by LLM plugins.

        There are two kinds of multimodal outputs:

        - Inlined data encoded in base64, which would be saved to storage directly.
        - Remote files referenced by an url, which would be downloaded and then saved to storage.

        Currently, only image files are supported.
        """
        if content.url != "":
            saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
        else:
            saved_file = file_saver.save_binary_string(
                data=base64.b64decode(content.base64_data),
                mime_type=content.mime_type,
                file_type=FileType.IMAGE,
            )
        return saved_file

    @staticmethod
    def fetch_structured_output_schema(
        *,
        structured_output: Mapping[str, Any],
    ) -> dict[str, Any]:
        """
        Fetch the structured output schema from the node data.

        Returns:
            dict[str, Any]: The structured output schema
        """
        if not structured_output:
            raise LLMNodeError("Please provide a valid structured output schema")

        structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
        if not structured_output_schema:
            raise LLMNodeError("Please provide a valid structured output schema")

        try:
            schema = json.loads(structured_output_schema)
            if not isinstance(schema, dict):
                raise LLMNodeError("structured_output_schema must be a JSON object")
            return schema
        except json.JSONDecodeError:
            raise LLMNodeError("structured_output_schema is not valid JSON format")

    @staticmethod
    def _save_multimodal_output_and_convert_result_to_markdown(
        *,
        contents: str | list[PromptMessageContentUnionTypes] | None,
        file_saver: LLMFileSaver,
        file_outputs: list[File],
    ) -> Generator[str, None, None]:
        """Convert intermediate prompt messages into strings and yield them to the caller.

        If the messages contain non-textual content (e.g., multimedia like images or videos),
        it will be saved separately, and the corresponding Markdown representation will be
        yielded to the caller.
        """
        # NOTE(QuantumGhost): This function should yield results to the caller immediately
        # whenever new content or partial content is available. Avoid any intermediate buffering
        # of results. Additionally, do not yield empty strings; instead, yield from an empty list
        # if necessary.
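        # Illustrative example (hypothetical values): contents of
        #   [TextPromptMessageContent(data="Here is the chart:"), ImagePromptMessageContent(...)]
        # yields "Here is the chart:" first, then a Markdown image link ("![](<url>)") for the
        # file saved through save_multimodal_image_output.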
        if contents is None:
            yield from []
            return
        if isinstance(contents, str):
            yield contents
        else:
            for item in contents:
                if isinstance(item, TextPromptMessageContent):
                    yield item.data
                elif isinstance(item, ImagePromptMessageContent):
                    file = LLMNode.save_multimodal_image_output(
                        content=item,
                        file_saver=file_saver,
                    )
                    file_outputs.append(file)
                    yield LLMNode._image_file_to_markdown(file)
                else:
                    logger.warning("unknown item type encountered, type=%s", type(item))
                    yield str(item)

    @property
    def retry(self) -> bool:
        return self.node_data.retry_config.retry_enabled

    @property
    def model_instance(self) -> ModelInstance:
        return self._model_instance


def _combine_message_content_with_role(
    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
    match role:
        case PromptMessageRole.USER:
            return UserPromptMessage(content=contents)
        case PromptMessageRole.ASSISTANT:
            return AssistantPromptMessage(content=contents)
        case PromptMessageRole.SYSTEM:
            return SystemPromptMessage(content=contents)
        case _:
            raise NotImplementedError(f"Role {role} is not supported")


def _render_jinja2_message(
    *,
    template: str,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
):
    if not template:
        return ""

    jinja2_inputs = {}
    for jinja2_variable in jinja2_variables:
        variable = variable_pool.get(jinja2_variable.value_selector)
        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""

    code_execute_resp = CodeExecutor.execute_workflow_code_template(
        language=CodeLanguage.JINJA2,
        code=template,
        inputs=jinja2_inputs,
    )
    result_text = code_execute_resp["result"]
    return result_text


def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
) -> int:
    rest_tokens = 2000

    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    runtime_model_parameters = model_instance.parameters
    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in runtime_model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    runtime_model_parameters.get(parameter_rule.name)
                    or runtime_model_parameters.get(str(parameter_rule.use_template))
                    or 0
                )

        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)

    return rest_tokens


def _handle_memory_chat_mode(
    *,
    memory: PromptMessageMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
    memory_messages: Sequence[PromptMessage] = []
    # Get messages from memory for chat model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
        )
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
            message_limit=memory_config.window.size if memory_config.window.enabled else None,
        )
    return memory_messages


def _handle_memory_completion_mode(
    *,
    memory: PromptMessageMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
) -> str:
    memory_text = ""

    # Get history text from memory for completion model
    if memory and memory_config:
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
        )
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
required for completion model.") memory_text = llm_utils.fetch_memory_text( memory=memory, max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, human_prefix=memory_config.role_prefix.user, ai_prefix=memory_config.role_prefix.assistant, ) return memory_text def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ) -> Sequence[PromptMessage]: """Handle completion template processing outside of LLMNode class. Args: template: The completion model prompt template context: Optional context string jinja2_variables: Variables for jinja2 template rendering variable_pool: Variable pool for template conversion Returns: Sequence of prompt messages """ prompt_messages = [] if template.edition_type == "jinja2": result_text = _render_jinja2_message( template=template.jinja2_text or "", jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) else: if context: template_text = template.text.replace("{#context#}", context) else: template_text = template.text result_text = variable_pool.convert_template(template_text).text prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER ) prompt_messages.append(prompt_message) return prompt_messages