mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'main' into feat/r2
This commit is contained in:
@ -36,7 +36,7 @@ class Graph(BaseModel):
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
node_id_config_mapping: dict[str, dict] = Field(
|
||||
default_factory=list, description="node configs mapping (node id: node config)"
|
||||
default_factory=dict, description="node configs mapping (node id: node config)"
|
||||
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict, description="graph edge mapping (source node id: edges)"
|
||||
|
||||
@ -95,7 +95,12 @@ class StreamProcessor(ABC):
|
||||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
if node_id in reachable_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
|
||||
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
continue
|
||||
|
||||
@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
depth: int = 1,
|
||||
):
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
|
||||
@ -11,6 +11,7 @@ import docx
|
||||
import pandas as pd
|
||||
import pypandoc # type: ignore
|
||||
import pypdfium2 # type: ignore
|
||||
import webvtt # type: ignore
|
||||
import yaml # type: ignore
|
||||
from docx.document import Document
|
||||
from docx.oxml.table import CT_Tbl
|
||||
@ -132,6 +133,10 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
return _extract_text_from_json(file_content)
|
||||
case "application/x-yaml" | "text/yaml":
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case "text/vtt":
|
||||
return _extract_text_from_vtt(file_content)
|
||||
case "text/properties":
|
||||
return _extract_text_from_properties(file_content)
|
||||
case _:
|
||||
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
|
||||
|
||||
@ -139,7 +144,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
|
||||
"""Extract text from a file based on its file extension."""
|
||||
match file_extension:
|
||||
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt":
|
||||
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case ".json":
|
||||
return _extract_text_from_json(file_content)
|
||||
@ -165,6 +170,10 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
return _extract_text_from_eml(file_content)
|
||||
case ".msg":
|
||||
return _extract_text_from_msg(file_content)
|
||||
case ".vtt":
|
||||
return _extract_text_from_vtt(file_content)
|
||||
case ".properties":
|
||||
return _extract_text_from_properties(file_content)
|
||||
case _:
|
||||
raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
|
||||
|
||||
@ -214,8 +223,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
|
||||
if not dify_config.UNSTRUCTURED_API_URL:
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
|
||||
@ -226,7 +235,7 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
@ -462,3 +471,68 @@ def _extract_text_from_msg(file_content: bytes) -> str:
|
||||
return "\n".join([str(element) for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
|
||||
text = _extract_text_from_plain_text(vtt_bytes)
|
||||
|
||||
# remove bom
|
||||
text = text.lstrip("\ufeff")
|
||||
|
||||
raw_results = []
|
||||
for caption in webvtt.from_string(text):
|
||||
raw_results.append((caption.voice, caption.text))
|
||||
|
||||
# Merge consecutive utterances by the same speaker
|
||||
merged_results = []
|
||||
if raw_results:
|
||||
current_speaker, current_text = raw_results[0]
|
||||
|
||||
for i in range(1, len(raw_results)):
|
||||
spk, txt = raw_results[i]
|
||||
if spk == None:
|
||||
merged_results.append((None, current_text))
|
||||
continue
|
||||
|
||||
if spk == current_speaker:
|
||||
# If it is the same speaker, merge the utterances (joined by space)
|
||||
current_text += " " + txt
|
||||
else:
|
||||
# If the speaker changes, register the utterance so far and move on
|
||||
merged_results.append((current_speaker, current_text))
|
||||
current_speaker, current_text = spk, txt
|
||||
|
||||
# Add the last element
|
||||
merged_results.append((current_speaker, current_text))
|
||||
else:
|
||||
merged_results = raw_results
|
||||
|
||||
# Return the result in the specified format: Speaker "text" style
|
||||
formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results]
|
||||
return "\n".join(formatted)
|
||||
|
||||
|
||||
def _extract_text_from_properties(file_content: bytes) -> str:
|
||||
try:
|
||||
text = _extract_text_from_plain_text(file_content)
|
||||
lines = text.splitlines()
|
||||
result = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
# Preserve comments and empty lines
|
||||
if not line or line.startswith("#") or line.startswith("!"):
|
||||
result.append(line)
|
||||
continue
|
||||
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
elif ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
else:
|
||||
key, value = line, ""
|
||||
|
||||
result.append(f"{key.strip()}: {value.strip()}")
|
||||
|
||||
return "\n".join(result)
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e
|
||||
|
||||
@ -262,7 +262,10 @@ class Executor:
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.auth.config.type == "basic":
|
||||
credentials = authorization.config.api_key
|
||||
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
|
||||
if ":" in credentials:
|
||||
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
|
||||
else:
|
||||
encoded_credentials = credentials
|
||||
headers[authorization.config.header] = f"Basic {encoded_credentials}"
|
||||
elif self.auth.config.type == "custom":
|
||||
headers[authorization.config.header] = authorization.config.api_key or ""
|
||||
|
||||
@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
mime_type = (
|
||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
)
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
|
||||
@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
return event
|
||||
|
||||
iter_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
if parallel_mode_run_id:
|
||||
# for parallel, the specific branch ID is more important than the sequential index
|
||||
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
|
||||
if self.node_data.is_parallel
|
||||
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
|
||||
if self.node_data.is_parallel
|
||||
else iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
current_metadata = event.route_node_state.node_run_result.metadata or {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
return event
|
||||
|
||||
def _run_single_iter(
|
||||
|
||||
@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import Integer, and_, func, or_, text
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
@ -32,11 +32,11 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
@ -264,6 +264,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
"data_source_type": "external",
|
||||
"retriever_from": "workflow",
|
||||
"score": item.metadata.get("score"),
|
||||
"doc_metadata": item.metadata,
|
||||
},
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
@ -275,12 +276,16 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
@ -289,7 +294,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"document_data_source_type": document.data_source_type,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": "workflow",
|
||||
"score": record.score or 0.0,
|
||||
@ -356,12 +361,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
||||
conditions = []
|
||||
if node_data.metadata_filtering_conditions:
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
@ -372,13 +377,24 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
@ -493,24 +509,24 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
|
||||
case "≤" | ">=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
|
||||
case _:
|
||||
pass
|
||||
return filters
|
||||
@ -618,7 +634,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
|
||||
@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
|
||||
### Job Description',
|
||||
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
|
||||
### Task
|
||||
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
|
||||
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
|
||||
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which company’s email address test@examp
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
|
||||
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
|
||||
</example>
|
||||
</example>
|
||||
### User Input
|
||||
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
|
||||
### Assistant Output
|
||||
|
||||
@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||
class FileTypeNotSupportError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"{type_name} type is not supported by this model")
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
||||
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
@ -0,0 +1,160 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||
LLM.
|
||||
"""
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
"""save_binary_string saves the inline file data returned by LLM.
|
||||
|
||||
Currently (2025-04-30), only some of Google Gemini models will return
|
||||
multimodal output as inline data.
|
||||
|
||||
:param data: the contents of the file
|
||||
:param mime_type: the media type of the file, specified by rfc6838
|
||||
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||
:param file_type: The file type of the inline file.
|
||||
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||
|
||||
The default value is `None`, which means do not override the file extension and guessing it
|
||||
from the `mime_type` attribute while saving the file.
|
||||
|
||||
Setting it to values other than `None` means override the file's extension, and
|
||||
will bypass the extension guessing saving the file.
|
||||
|
||||
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||
|
||||
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
|
||||
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
tool_file_manager = self._get_tool_file_manager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
# TODO(QuantumGhost): what is conversation id?
|
||||
conversation_id=None,
|
||||
file_binary=data,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
extension_override = _validate_extension_override(extension_override)
|
||||
extension = _get_extension(mime_type, extension_override)
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
return File(
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=tool_file.name,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||
"""get_extension return the extension of file.
|
||||
|
||||
If the `extension_override` parameter is set, this function should honor it and
|
||||
return its value.
|
||||
"""
|
||||
if extension_override is not None:
|
||||
return extension_override
|
||||
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||
|
||||
|
||||
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||
"""_extract_content_type_and_extension tries to
|
||||
guess content type of file from url and `Content-Type` header in response.
|
||||
"""
|
||||
if content_type_header:
|
||||
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||
return content_type_header, extension
|
||||
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||
return content_type, extension
|
||||
|
||||
|
||||
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||
# `extension_override` is allow to be `None or `""`.
|
||||
if extension_override is None:
|
||||
return None
|
||||
if extension_override == "":
|
||||
return ""
|
||||
if not extension_override.startswith("."):
|
||||
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||
return extension_override
|
||||
@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@ -95,9 +96,13 @@ from .exc import (
|
||||
TemplateTypeNotSupportError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -106,8 +111,45 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
|
||||
return None
|
||||
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
structured_output = process_structured_output(result_text)
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = self._file_outputs
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("error while executing llm node")
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=message_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||
yield event
|
||||
return
|
||||
|
||||
model = None
|
||||
# For streaming mode
|
||||
model = ""
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ""
|
||||
usage = None
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
for result in invoke_result:
|
||||
text = convert_llm_result_chunk_to_str(result.delta.message.content)
|
||||
full_text += text
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
if not model:
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not finish_reason and result.delta.finish_reason:
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
def _image_file_to_markdown(self, file: "File", /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
@ -459,7 +506,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("document_data_source_type"),
|
||||
"data_source_type": metadata.get("data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
@ -750,18 +797,22 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
stop = model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
|
||||
structured_output: dict[str, Any] | list[Any] = {}
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
try:
|
||||
parsed = json.loads(result_text)
|
||||
if not isinstance(parsed, (dict | list)):
|
||||
if not isinstance(parsed, dict):
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
except json.JSONDecodeError as e:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
parsed = json_repair.loads(result_text)
|
||||
if not isinstance(parsed, (dict | list)):
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
if not isinstance(parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(parsed, list):
|
||||
parsed = next((item for item in parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
return structured_output
|
||||
|
||||
@ -963,6 +1014,42 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
|
||||
buffer = io.StringIO()
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
|
||||
buffer.write(text_part)
|
||||
|
||||
return ModelInvokeCompletedEvent(
|
||||
text=buffer.getvalue(),
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
|
||||
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||
|
||||
There are two kinds of multimodal outputs:
|
||||
|
||||
- Inlined data encoded in base64, which would be saved to storage directly.
|
||||
- Remote files referenced by an url, which would be downloaded and then saved to storage.
|
||||
|
||||
Currently, only image files are supported.
|
||||
"""
|
||||
# Inject the saver somehow...
|
||||
_saver = self._llm_file_saver
|
||||
|
||||
# If this
|
||||
if content.url != "":
|
||||
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
|
||||
else:
|
||||
saved_file = _saver.save_binary_string(
|
||||
data=base64.b64decode(content.base64_data),
|
||||
mime_type=content.mime_type,
|
||||
file_type=FileType.IMAGE,
|
||||
)
|
||||
self._file_outputs.append(saved_file)
|
||||
return saved_file
|
||||
|
||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
@ -1123,6 +1210,41 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||
)
|
||||
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
self,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||
|
||||
If the messages contain non-textual content (e.g., multimedia like images or videos),
|
||||
it will be saved separately, and the corresponding Markdown representation will
|
||||
be yielded to the caller.
|
||||
"""
|
||||
|
||||
# NOTE(QuantumGhost): This function should yield results to the caller immediately
|
||||
# whenever new content or partial content is available. Avoid any intermediate buffering
|
||||
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
|
||||
# if necessary.
|
||||
if contents is None:
|
||||
yield from []
|
||||
return
|
||||
if isinstance(contents, str):
|
||||
yield contents
|
||||
elif isinstance(contents, list):
|
||||
for item in contents:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
yield item.data
|
||||
elif isinstance(item, ImagePromptMessageContent):
|
||||
file = self._save_multimodal_image_output(item)
|
||||
self._file_outputs.append(file)
|
||||
yield self._image_file_to_markdown(file)
|
||||
else:
|
||||
logger.warning("unknown item type encountered, type=%s", type(item))
|
||||
yield str(item)
|
||||
else:
|
||||
logger.warning("unknown contents type encountered, type=%s", type(contents))
|
||||
yield str(contents)
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
|
||||
@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
|
||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
|
||||
@ -337,7 +337,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
return {"check_break_result": True}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
|
||||
@ -17,7 +17,7 @@ Some additional information is provided below. Always adhere to these instructio
|
||||
</instruction>
|
||||
Steps:
|
||||
1. Review the chat history provided within the <histories> tags.
|
||||
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
|
||||
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
|
||||
3. Generate a well-formatted output using the defined functions and arguments.
|
||||
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
|
||||
5. Do not include any XML tags in your output.
|
||||
@ -89,13 +89,13 @@ Some extra information are provided below, I should always follow the instructio
|
||||
</instructions>
|
||||
|
||||
### Extract parameter Workflow
|
||||
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
||||
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
||||
<information to be extracted>
|
||||
{{ structure }}
|
||||
</information to be extracted>
|
||||
|
||||
Step 1: Carefully read the input and understand the structure of the expected output.
|
||||
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
|
||||
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
|
||||
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
|
||||
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
|
||||
|
||||
@ -106,10 +106,10 @@ Here are the chat histories between human and assistant, inside <histories></his
|
||||
</histories>
|
||||
|
||||
### Structure
|
||||
Here is the structure of the expected output, I should always follow the output structure.
|
||||
Here is the structure of the expected output, I should always follow the output structure.
|
||||
{{γγγ
|
||||
'properties1': 'relevant text extracted from input',
|
||||
'properties2': 'relevant text extracted from input',
|
||||
'properties1': 'relevant text extracted from input',
|
||||
'properties2': 'relevant text extracted from input',
|
||||
}}γγγ
|
||||
|
||||
### Input Text
|
||||
@ -119,7 +119,7 @@ Inside <text></text> XML tags, there is a text that I should extract parameters
|
||||
</text>
|
||||
|
||||
### Answer
|
||||
I should always output a valid JSON object. Output nothing other than the JSON object.
|
||||
I should always output a valid JSON object. Output nothing other than the JSON object.
|
||||
```JSON
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ You are a text classification engine that analyzes text data and assigns categor
|
||||
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
|
||||
### Format
|
||||
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
|
||||
### Constraint
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
@ -64,7 +64,7 @@ User:{{"input_text": ["I recently had a great experience with your company. The
|
||||
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
|
||||
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
|
||||
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
|
||||
</example>
|
||||
</example>
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
|
||||
@ -11,6 +11,8 @@ class Operation(StrEnum):
|
||||
SUBTRACT = "-="
|
||||
MULTIPLY = "*="
|
||||
DIVIDE = "/="
|
||||
REMOVE_FIRST = "remove-first"
|
||||
REMOVE_LAST = "remove-last"
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
|
||||
@ -23,6 +23,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -51,7 +60,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
|
||||
|
||||
|
||||
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
|
||||
if operation == Operation.CLEAR:
|
||||
if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}:
|
||||
return True
|
||||
match variable_type:
|
||||
case SegmentType.STRING:
|
||||
|
||||
@ -64,7 +64,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
# Get value from variable pool
|
||||
if (
|
||||
item.input_type == InputType.VARIABLE
|
||||
and item.operation != Operation.CLEAR
|
||||
and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
|
||||
and item.value is not None
|
||||
):
|
||||
value = self.graph_runtime_state.variable_pool.get(item.value)
|
||||
@ -165,5 +165,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
return variable.value * value
|
||||
case Operation.DIVIDE:
|
||||
return variable.value / value
|
||||
case Operation.REMOVE_FIRST:
|
||||
# If array is empty, do nothing
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[1:]
|
||||
case Operation.REMOVE_LAST:
|
||||
# If array is empty, do nothing
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[:-1]
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
|
||||
|
||||
14
api/core/workflow/repository/__init__.py
Normal file
14
api/core/workflow/repository/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Repository interfaces for data access.
|
||||
|
||||
This package contains repository interfaces that define the contract
|
||||
for accessing and manipulating data, regardless of the underlying
|
||||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"OrderConfig",
|
||||
"WorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,97 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderConfig:
|
||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Repository interface for WorkflowNodeExecution.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||
|
||||
This method is intended to be used for bulk deletion operations, such as removing
|
||||
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||
"""
|
||||
...
|
||||
@ -39,7 +39,7 @@ class SubCondition(BaseModel):
|
||||
|
||||
class SubVariableCondition(BaseModel):
|
||||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[SubCondition] = Field(default=list)
|
||||
conditions: list[SubCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
|
||||
639
api/core/workflow/workflow_app_generate_task_pipeline.py
Normal file
639
api/core/workflow/workflow_app_generate_task_pipeline.py
Normal file
@ -0,0 +1,639 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if isinstance(user, EndUser):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatedByRole.END_USER
|
||||
elif isinstance(user, Account):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatedByRole.ACCOUNT
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
|
||||
"""
|
||||
To blocking response.
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.id,
|
||||
workflow_id=stream_response.data.workflow_id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs,
|
||||
error=stream_response.data.error,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at),
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
raise ValueError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
"""
|
||||
workflow_run_id = None
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, WorkflowStartStreamResponse):
|
||||
workflow_run_id = stream_response.workflow_run_id
|
||||
|
||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||
|
||||
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg = publisher.check_and_get_audio()
|
||||
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
tts_publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow_features_dict
|
||||
|
||||
if (
|
||||
features_dict.get("text_to_speech")
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
|
||||
)
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio_trunk.status == "finish":
|
||||
break
|
||||
else:
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception:
|
||||
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
|
||||
break
|
||||
if tts_publisher:
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
graph_runtime_state = None
|
||||
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._base_task_pipeline._handle_error(event=event)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
created_by_role=self._created_by_role,
|
||||
)
|
||||
self._workflow_run_id = workflow_run.id
|
||||
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield start_resp
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
event=event
|
||||
)
|
||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
):
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield iter_start_resp
|
||||
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield iter_next_resp
|
||||
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield iter_finish_resp
|
||||
|
||||
elif isinstance(event, QueueLoopStartEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_start_resp
|
||||
|
||||
elif isinstance(event, QueueLoopNextEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_next_resp
|
||||
|
||||
elif isinstance(event, QueueLoopCompletedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_finish_resp
|
||||
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED
|
||||
if isinstance(event, QueueWorkflowFailedEvent)
|
||||
else WorkflowRunStatus.STOPPED,
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Save workflow app log.
|
||||
:return:
|
||||
"""
|
||||
invoke_from = self._application_generate_entity.invoke_from
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
elif invoke_from == InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
elif invoke_from == InvokeFrom.WEB_APP:
|
||||
created_from = WorkflowAppLogCreatedFrom.WEB_APP
|
||||
else:
|
||||
# not save log for debugging
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = workflow_run.tenant_id
|
||||
workflow_app_log.app_id = workflow_run.app_id
|
||||
workflow_app_log.workflow_id = workflow_run.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_run.id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = self._created_by_role
|
||||
workflow_app_log.created_by = self._user_id
|
||||
|
||||
session.add(workflow_app_log)
|
||||
session.commit()
|
||||
|
||||
def _text_chunk_to_stream_response(
|
||||
self, text: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> TextChunkStreamResponse:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
response = TextChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
|
||||
)
|
||||
|
||||
return response
|
||||
948
api/core/workflow/workflow_cycle_manager.py
Normal file
948
api/core/workflow/workflow_cycle_manager.py
Normal file
@ -0,0 +1,948 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
LoopNodeCompletedStreamResponse,
|
||||
LoopNodeNextStreamResponse,
|
||||
LoopNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
created_by_role: CreatedByRole,
|
||||
) -> WorkflowRun:
|
||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.scalar(workflow_stmt)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
|
||||
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
|
||||
WorkflowRun.tenant_id == workflow.tenant_id,
|
||||
WorkflowRun.app_id == workflow.app_id,
|
||||
)
|
||||
max_sequence = session.scalar(max_sequence_stmt) or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == "conversation":
|
||||
continue
|
||||
inputs[f"sys.{key.value}"] = value
|
||||
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
|
||||
# handle special values
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
workflow_run.tenant_id = workflow.tenant_id
|
||||
workflow_run.app_id = workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = workflow.id
|
||||
workflow_run.type = workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = workflow.version
|
||||
workflow_run.graph = workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.created_by_role = created_by_role
|
||||
workflow_run.created_by = user_id
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_run)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run_id: workflow run id
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_partial_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
exceptions_count: int = 0,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run_id: workflow run id
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
# Use the instance repository to find running executions for a workflow run
|
||||
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||
workflow_run_id=workflow_run.id
|
||||
)
|
||||
|
||||
# Update the cache with the retrieved executions
|
||||
for execution in running_workflow_node_executions:
|
||||
if execution.node_execution_id:
|
||||
self._workflow_node_executions[execution.node_execution_id] = execution
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_execution_start(
|
||||
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
# Use the instance repository to save the workflow node execution
|
||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
# Use the instance repository to update the workflow node execution
|
||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
workflow_node_execution.status = (
|
||||
WorkflowNodeExecutionStatus.FAILED.value
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
)
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_run: workflow run
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
created_at = event.start_at
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
origin_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
merged_metadata = (
|
||||
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
|
||||
if event.execution_metadata is not None
|
||||
else origin_metadata
|
||||
)
|
||||
execution_metadata = json.dumps(merged_metadata)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = created_at
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
# Use the instance repository to save the workflow node execution
|
||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
_ = session
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=dict(workflow_run.inputs_dict or {}),
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatedByRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
status=workflow_run.status,
|
||||
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_retry_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
retry_index=event.retry_index,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
_ = session
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
_ = session
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
_ = session
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_next_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
||||
) -> IterationNodeNextStreamResponse:
|
||||
_ = session
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_completed_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
||||
) -> IterationNodeCompletedStreamResponse:
|
||||
_ = session
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
||||
) -> LoopNodeStartStreamResponse:
|
||||
_ = session
|
||||
return LoopNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_next_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
||||
) -> LoopNodeNextStreamResponse:
|
||||
_ = session
|
||||
return LoopNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_completed_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
||||
) -> LoopNodeCompletedStreamResponse:
|
||||
_ = session
|
||||
return LoopNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
:return:
|
||||
"""
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
# Flatten the list of sequences into a single list of mappings
|
||||
flattened_files = [file for sublist in files if sublist for file in sublist]
|
||||
|
||||
# Convert to tuple to match Sequence type
|
||||
return tuple(flattened_files)
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file = self._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(value, dict):
|
||||
file = self._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||
cached_workflow_run = self._workflow_run
|
||||
cached_workflow_run = session.merge(cached_workflow_run)
|
||||
return cached_workflow_run
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
# First check the cache for performance
|
||||
if node_execution_id in self._workflow_node_executions:
|
||||
cached_execution = self._workflow_node_executions[node_execution_id]
|
||||
# No need to merge with session since expire_on_commit=False
|
||||
return cached_execution
|
||||
|
||||
# If not in cache, use the instance repository to get by node_execution_id
|
||||
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
|
||||
|
||||
if not execution:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
# Update cache
|
||||
self._workflow_node_executions[node_execution_id] = execution
|
||||
return execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
:param task_id: task id
|
||||
:param event: agent log event
|
||||
:return:
|
||||
"""
|
||||
return AgentLogStreamResponse(
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
),
|
||||
)
|
||||
@ -9,6 +9,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.workflow.callbacks import WorkflowCallback
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
|
||||
@ -364,4 +365,5 @@ class WorkflowEntry:
|
||||
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
||||
Reference in New Issue
Block a user