Merge branch 'main' into feat/r2

This commit is contained in:
jyong
2025-05-15 15:15:23 +08:00
1025 changed files with 17699 additions and 4959 deletions

View File

@ -36,7 +36,7 @@ class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=list, description="node configs mapping (node id: node config)"
default_factory=dict, description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="graph edge mapping (source node id: edges)"

View File

@ -95,7 +95,12 @@ class StreamProcessor(ABC):
if node_id not in self.rest_node_ids:
return
if node_id in reachable_node_ids:
return
self.rest_node_ids.remove(node_id)
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue

View File

@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]):
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:

View File

@ -11,6 +11,7 @@ import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
import webvtt # type: ignore
import yaml # type: ignore
from docx.document import Document
from docx.oxml.table import CT_Tbl
@ -132,6 +133,10 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
return _extract_text_from_json(file_content)
case "application/x-yaml" | "text/yaml":
return _extract_text_from_yaml(file_content)
case "text/vtt":
return _extract_text_from_vtt(file_content)
case "text/properties":
return _extract_text_from_properties(file_content)
case _:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
@ -139,7 +144,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt":
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
return _extract_text_from_plain_text(file_content)
case ".json":
return _extract_text_from_json(file_content)
@ -165,6 +170,10 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
return _extract_text_from_eml(file_content)
case ".msg":
return _extract_text_from_msg(file_content)
case ".vtt":
return _extract_text_from_vtt(file_content)
case ".properties":
return _extract_text_from_properties(file_content)
case _:
raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
@ -214,8 +223,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
"""
from unstructured.partition.api import partition_via_api
if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
if not dify_config.UNSTRUCTURED_API_URL:
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
@ -226,7 +235,7 @@ def _extract_text_from_doc(file_content: bytes) -> str:
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
)
os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements])
@ -462,3 +471,68 @@ def _extract_text_from_msg(file_content: bytes) -> str:
return "\n".join([str(element) for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
text = _extract_text_from_plain_text(vtt_bytes)
# remove bom
text = text.lstrip("\ufeff")
raw_results = []
for caption in webvtt.from_string(text):
raw_results.append((caption.voice, caption.text))
# Merge consecutive utterances by the same speaker
merged_results = []
if raw_results:
current_speaker, current_text = raw_results[0]
for i in range(1, len(raw_results)):
spk, txt = raw_results[i]
if spk == None:
merged_results.append((None, current_text))
continue
if spk == current_speaker:
# If it is the same speaker, merge the utterances (joined by space)
current_text += " " + txt
else:
# If the speaker changes, register the utterance so far and move on
merged_results.append((current_speaker, current_text))
current_speaker, current_text = spk, txt
# Add the last element
merged_results.append((current_speaker, current_text))
else:
merged_results = raw_results
# Return the result in the specified format: Speaker "text" style
formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results]
return "\n".join(formatted)
def _extract_text_from_properties(file_content: bytes) -> str:
try:
text = _extract_text_from_plain_text(file_content)
lines = text.splitlines()
result = []
for line in lines:
line = line.strip()
# Preserve comments and empty lines
if not line or line.startswith("#") or line.startswith("!"):
result.append(line)
continue
if "=" in line:
key, value = line.split("=", 1)
elif ":" in line:
key, value = line.split(":", 1)
else:
key, value = line, ""
result.append(f"{key.strip()}: {value.strip()}")
return "\n".join(result)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e

View File

@ -262,7 +262,10 @@ class Executor:
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.auth.config.type == "basic":
credentials = authorization.config.api_key
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
if ":" in credentials:
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
else:
encoded_credentials = credentials
headers[authorization.config.header] = f"Basic {encoded_credentials}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""

View File

@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file = ToolFileManager.create_file_by_raw(
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,

View File

@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]):
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
return event
iter_metadata = {
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
current_metadata = event.route_node_state.node_run_result.metadata or {}
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event
def _run_single_iter(

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -32,11 +32,11 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -264,6 +264,7 @@ class KnowledgeRetrievalNode(LLMNode):
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
@ -275,12 +276,16 @@ class KnowledgeRetrievalNode(LLMNode):
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"metadata": {
@ -289,7 +294,7 @@ class KnowledgeRetrievalNode(LLMNode):
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"document_data_source_type": document.data_source_type,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
@ -356,12 +361,12 @@ class KnowledgeRetrievalNode(LLMNode):
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
conditions = []
if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
@ -372,13 +377,24 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
else:
raise ValueError("Invalid expected metadata value type")
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
@ -493,24 +509,24 @@ class KnowledgeRetrievalNode(LLMNode):
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
case "is not" | "":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
case "" | "<=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
case _:
pass
return filters
@ -618,7 +634,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(

View File

@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "", ">", "<", "", "", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which companys email address test@examp
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output

View File

@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@ -0,0 +1,160 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
class LLMFileSaver(tp.Protocol):
"""LLMFileSaver is responsible for save multimodal output returned by
LLM.
"""
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
"""save_binary_string saves the inline file data returned by LLM.
Currently (2025-04-30), only some of Google Gemini models will return
multimodal output as inline data.
:param data: the contents of the file
:param mime_type: the media type of the file, specified by rfc6838
(https://datatracker.ietf.org/doc/html/rfc6838)
:param file_type: The file type of the inline file.
:param extension_override: Override the auto-detected file extension while saving this file.
The default value is `None`, which means do not override the file extension and guessing it
from the `mime_type` attribute while saving the file.
Setting it to values other than `None` means override the file's extension, and
will bypass the extension guessing saving the file.
Specially, setting it to empty string (`""`) will leave the file extension empty.
When it is not `None` or empty string (`""`), it should be a string beginning with a
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
pass
def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
Currently (2025-04-30), no model returns multimodel output as a url.
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id?
conversation_id=None,
file_binary=data,
mimetype=mime_type,
)
extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File(
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name,
extension=extension,
mime_type=mime_type,
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

View File

@ -1,3 +1,5 @@
import base64
import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -95,9 +96,13 @@ from .exc import (
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
@ -106,8 +111,45 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
return None
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
except Exception as e:
logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
return self._handle_invoke_result(invoke_result=invoke_result)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
) -> Generator[NodeEvent, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
yield ModelInvokeCompletedEvent(
text=message_text,
usage=invoke_result.usage,
finish_reason=None,
)
event = self._handle_blocking_result(invoke_result=invoke_result)
yield event
return
model = None
# For streaming mode
model = ""
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
for result in invoke_result:
text = convert_llm_result_chunk_to_str(result.delta.message.content)
full_text += text
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
if not model:
# Update the whole metadata
if not model and result.model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
if len(prompt_messages) == 0:
# TODO(QuantumGhost): it seems that this update has no visable effect.
# What's the purpose of the line below?
prompt_messages = list(result.prompt_messages)
if usage.prompt_tokens == 0 and result.delta.usage:
usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@ -459,7 +506,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("document_data_source_type"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
@ -750,18 +797,22 @@ class LLMNode(BaseNode[LLMNodeData]):
stop = model_config.stop
return filtered_prompt_messages, stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
structured_output: dict[str, Any] | list[Any] = {}
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
try:
parsed = json.loads(result_text)
if not isinstance(parsed, (dict | list)):
if not isinstance(parsed, dict):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except json.JSONDecodeError as e:
# if the result_text is not a valid json, try to repair it
parsed = json_repair.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
if not isinstance(parsed, dict):
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
if isinstance(parsed, list):
parsed = next((item for item in parsed if isinstance(item, dict)), {})
else:
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
return structured_output
@ -963,6 +1014,42 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
usage=invoke_result.usage,
finish_reason=None,
)
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
- Inlined data encoded in base64, which would be saved to storage directly.
- Remote files referenced by an url, which would be downloaded and then saved to storage.
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
@ -1123,6 +1210,41 @@ class LLMNode(BaseNode[LLMNodeData]):
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
If the messages contain non-textual content (e.g., multimedia like images or videos),
it will be saved separately, and the corresponding Markdown representation will
be yielded to the caller.
"""
# NOTE(QuantumGhost): This function should yield results to the caller immediately
# whenever new content or partial content is available. Avoid any intermediate buffering
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
# if necessary.
if contents is None:
yield from []
return
if isinstance(contents, str):
yield contents
elif isinstance(contents, list):
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
else:
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole

View File

@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None

View File

@ -337,7 +337,7 @@ class LoopNode(BaseNode[LoopNodeData]):
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield event
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,

View File

@ -17,7 +17,7 @@ Some additional information is provided below. Always adhere to these instructio
</instruction>
Steps:
1. Review the chat history provided within the <histories> tags.
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
5. Do not include any XML tags in your output.
@ -89,13 +89,13 @@ Some extra information are provided below, I should always follow the instructio
</instructions>
### Extract parameter Workflow
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
<information to be extracted>
{{ structure }}
</information to be extracted>
Step 1: Carefully read the input and understand the structure of the expected output.
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
@ -106,10 +106,10 @@ Here are the chat histories between human and assistant, inside <histories></his
</histories>
### Structure
Here is the structure of the expected output, I should always follow the output structure.
Here is the structure of the expected output, I should always follow the output structure.
{{γγγ
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
@ -119,7 +119,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
</text>
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
""" # noqa: E501

View File

@ -55,7 +55,7 @@ You are a text classification engine that analyzes text data and assigns categor
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
@ -64,7 +64,7 @@ User:{{"input_text": ["I recently had a great experience with your company. The
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
</example>
</example>
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>

View File

@ -11,6 +11,8 @@ class Operation(StrEnum):
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
REMOVE_FIRST = "remove-first"
REMOVE_LAST = "remove-last"
class InputType(StrEnum):

View File

@ -23,6 +23,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case _:
return False
@ -51,7 +60,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
if operation == Operation.CLEAR:
if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}:
return True
match variable_type:
case SegmentType.STRING:

View File

@ -64,7 +64,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
@ -165,5 +165,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case Operation.REMOVE_FIRST:
# If array is empty, do nothing
if not variable.value:
return variable.value
return variable.value[1:]
case Operation.REMOVE_LAST:
# If array is empty, do nothing
if not variable.value:
return variable.value
return variable.value[:-1]
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)

View File

@ -0,0 +1,14 @@
"""
Repository interfaces for data access.
This package contains repository interfaces that define the contract
for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
__all__ = [
"OrderConfig",
"WorkflowNodeExecutionRepository",
]

View File

@ -0,0 +1,97 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from models.workflow import WorkflowNodeExecution
@dataclass
class OrderConfig:
"""Configuration for ordering WorkflowNodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
class WorkflowNodeExecutionRepository(Protocol):
"""
Repository interface for WorkflowNodeExecution.
This interface defines the contract for accessing and manipulating
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and trigger sources (triggered_from) should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to save
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
...
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to update
"""
...
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...

View File

@ -39,7 +39,7 @@ class SubCondition(BaseModel):
class SubVariableCondition(BaseModel):
logical_operator: Literal["and", "or"]
conditions: list[SubCondition] = Field(default=list)
conditions: list[SubCondition] = Field(default_factory=list)
class Condition(BaseModel):

View File

@ -0,0 +1,639 @@
import logging
import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
),
)
return response
else:
continue
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
"""
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception:
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
graph_runtime_state = None
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if response:
yield response
elif isinstance(event, QueueNodeStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_success_response:
yield node_success_response
elif isinstance(
event,
QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueLoopStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield loop_start_resp
elif isinstance(event, QueueLoopNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield loop_next_resp
elif isinstance(event, QueueLoopCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield loop_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
"""
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)
return response

View File

@ -0,0 +1,948 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
LoopNodeCompletedStreamResponse,
LoopNodeNextStreamResponse,
LoopNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
)
class WorkflowCycleManager:
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_node_execution_repository = workflow_node_execution_repository
def _handle_workflow_run_start(
self,
*,
session: Session,
workflow_id: str,
user_id: str,
created_by_role: CreatedByRole,
) -> WorkflowRun:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.tenant_id == workflow.tenant_id,
WorkflowRun.app_id == workflow.app_id,
)
max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
continue
inputs[f"sys.{key.value}"] = value
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
workflow_run.tenant_id = workflow.tenant_id
workflow_run.app_id = workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = workflow.id
workflow_run.type = workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = workflow.version
workflow_run.graph = workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = created_by_role
workflow_run.created_by = user_id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_run)
return workflow_run
def _handle_workflow_run_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
def _handle_workflow_run_partial_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
def _handle_workflow_run_failed(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_run.id
)
# Update the cache with the retrieved executions
for execution in running_workflow_node_executions:
if execution.node_execution_id:
self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
def _handle_node_execution_start(
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
}
)
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata_dict = dict(event.execution_metadata or {})
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
# Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
origin_metadata = {
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
}
merged_metadata = (
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
if event.execution_metadata is not None
else origin_metadata
)
execution_metadata = json.dumps(merged_metadata)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _workflow_start_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse:
_ = session
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=WorkflowStartStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
def _workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowFinishStreamResponse:
created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
if account:
created_by = {
"id": account.id,
"name": account.name,
"email": account.email,
}
elif workflow_run.created_by_role == CreatedByRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
}
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
def _workflow_node_start_to_stream_response(
self,
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
return response
def _workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
)
def _workflow_node_retry_to_stream_response(
self,
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
_ = session
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
_ = session
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
_ = session
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
_ = session
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
_ = session
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
_ = session
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse:
_ = session
return LoopNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_loop_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse:
_ = session
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
:return:
"""
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file = self._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
# First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
# Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
),
)

View File

@ -9,6 +9,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@ -364,4 +365,5 @@ class WorkflowEntry:
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
# append variable and value to variable pool
variable_pool.add([variable_node_id] + variable_key_list, input_value)
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
variable_pool.add([variable_node_id] + variable_key_list, input_value)