mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 19:27:40 +08:00
Merge remote-tracking branch 'origin/main' into feat/support-agent-sandbox
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -300,6 +302,7 @@ class BaseAgentRunner(AppRunner):
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
tool_process_data=None,
|
||||
thought="",
|
||||
tool=tool_name,
|
||||
tool_labels_str="{}",
|
||||
@ -307,20 +310,20 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_unit_price=Decimal(0),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
message_files=json.dumps(messages_ids) if messages_ids else "",
|
||||
answer="",
|
||||
observation="",
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
answer_unit_price=Decimal(0),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
total_price=Decimal(0),
|
||||
position=self.agent_thought_count + 1,
|
||||
currency="USD",
|
||||
latency=0,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
@ -353,7 +356,8 @@ class BaseAgentRunner(AppRunner):
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
if thought:
|
||||
agent_thought.thought += thought
|
||||
existing_thought = agent_thought.thought or ""
|
||||
agent_thought.thought = f"{existing_thought}{thought}"
|
||||
|
||||
if tool_name:
|
||||
agent_thought.tool = tool_name
|
||||
@ -451,21 +455,30 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(";")
|
||||
tool_names_raw = agent_thought.tool
|
||||
if tool_names_raw:
|
||||
tool_names = tool_names_raw.split(";")
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
tool_input_payload = agent_thought.tool_input
|
||||
if tool_input_payload:
|
||||
try:
|
||||
tool_inputs = json.loads(tool_input_payload)
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
else:
|
||||
tool_inputs = {tool: {} for tool in tool_names}
|
||||
|
||||
for tool in tools:
|
||||
observation_payload = agent_thought.observation
|
||||
if observation_payload:
|
||||
try:
|
||||
tool_responses = json.loads(observation_payload)
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
else:
|
||||
tool_responses = dict.fromkeys(tool_names, observation_payload)
|
||||
|
||||
for tool in tool_names:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
@ -495,7 +508,7 @@ class BaseAgentRunner(AppRunner):
|
||||
*tool_call_response,
|
||||
]
|
||||
)
|
||||
if not tools:
|
||||
if not tool_names_raw:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
if message.answer:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
),
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
|
||||
@ -166,7 +166,7 @@ class SpanBuilder:
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
kind=span_data.span_kind,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -34,3 +34,4 @@ class SpanData(BaseModel):
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
|
||||
|
||||
@ -212,6 +212,10 @@ class WorkflowExecutionStatus(StrEnum):
|
||||
def is_ended(self) -> bool:
|
||||
return self in _END_STATE
|
||||
|
||||
@classmethod
|
||||
def ended_values(cls) -> list[str]:
|
||||
return [status.value for status in _END_STATE]
|
||||
|
||||
|
||||
_END_STATE = frozenset(
|
||||
[
|
||||
|
||||
@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
||||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
assigned_selector = tuple(self.node_data.assigned_variable_selector)
|
||||
return assigned_selector in variable_selectors
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@ -136,13 +137,11 @@ class WorkflowEntry:
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config = dict(workflow.get_node_config_by_id(node_id))
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -158,12 +157,12 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
|
||||
Reference in New Issue
Block a user