Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-12-27 17:33:58 +08:00
416 changed files with 14790 additions and 5027 deletions

View File

@ -613,10 +613,10 @@ class Graph(BaseModel):
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids:
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
if node_id2 in merge_branch_node_ids:
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {}

View File

@ -48,9 +48,11 @@ class StreamProcessor(ABC):
# we remove the node maybe shortcut the answer node, so comment this code for now
# there is not effect on the answer node and the workflow, when we have a better solution
# we can open this code. Issues: #11542 #9560 #10638 #10564
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
if "answer" in ids:
continue
else:
reachable_node_ids.extend(ids)
else:
unreachable_first_node_ids.append(edge.target_node_id)

View File

@ -20,3 +20,7 @@ class ResponseSizeError(HttpRequestNodeError):
class RequestBodyError(HttpRequestNodeError):
"""Raised when the request body is invalid."""
class InvalidURLError(HttpRequestNodeError):
"""Raised when the URL is invalid."""

View File

@ -23,6 +23,7 @@ from .exc import (
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
InvalidURLError,
RequestBodyError,
ResponseSizeError,
)
@ -66,6 +67,12 @@ class Executor:
node_data.authorization.config.api_key
).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization

View File

@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData
@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"document_data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None),
"score": record.score or 0.0,
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
position += 1
return retrieval_resource_list
@classmethod

View File

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
@ -189,10 +189,12 @@ class ToolNode(BaseNode[ToolNodeData]):
conversation_id=None,
)
files: list[File] = []
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLog] = []
variables: dict[str, Any] = {}
for message in message_stream:
@ -239,14 +241,16 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
File(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
extension=None,
mime_type=message.meta.get("mime_type", "application/octet-stream"),
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT: