Merge branch 'main' into feat/mcp

This commit is contained in:
Novice
2025-07-09 09:41:42 +08:00
234 changed files with 8742 additions and 1254 deletions

View File

@ -27,6 +27,9 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
@ -36,7 +39,10 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
)
logger = logging.getLogger(__name__)
@ -450,6 +456,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -521,6 +528,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -547,6 +555,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)
try:

View File

@ -64,6 +64,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, W
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline:
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -153,6 +155,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -371,6 +374,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_node_execution=workflow_node_execution,
)
session.commit()
self._save_output_for_event(event, workflow_node_execution.id)
if node_finish_resp:
yield node_finish_resp
@ -390,6 +394,8 @@ class AdvancedChatAppGenerateTaskPipeline:
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_finish_resp:
yield node_finish_resp
@ -759,3 +765,15 @@ class AdvancedChatAppGenerateTaskPipeline:
if not message:
raise ValueError(f"Message not found: {self._message_id}")
return message
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

View File

@ -1,10 +1,20 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from factories import file_factory
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
@ -159,3 +169,38 @@ class BaseAppGenerator:
yield f"event: {message}\n\n"
return gen()
@final
@staticmethod
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return DraftVariableSaverImpl(
session=session,
app_id=app_id,
node_id=node_id,
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
)
else:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return NoopDraftVariableSaver()
return draft_var_saver_factory

View File

@ -44,6 +44,7 @@ from core.app.entities.task_entities import (
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
@ -506,7 +507,8 @@ class WorkflowResponseConverter:
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
@classmethod
def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
@ -515,20 +517,30 @@ class WorkflowResponseConverter:
if not value:
return []
files = []
if isinstance(value, list):
files: list[Mapping[str, Any]] = []
if isinstance(value, FileSegment):
files.append(value.value.to_dict())
elif isinstance(value, ArrayFileSegment):
files.extend([i.to_dict() for i in value.value])
elif isinstance(value, File):
files.append(value.to_dict())
elif isinstance(value, list):
for item in value:
file = self._get_file_var_from_value(item)
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file = self._get_file_var_from_value(value)
elif isinstance(
value,
dict,
):
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
@classmethod
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value

View File

@ -25,6 +25,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
@ -236,6 +237,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
@ -244,6 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@ -474,6 +480,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -494,6 +501,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

View File

@ -56,6 +56,7 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
@ -87,6 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -131,6 +133,8 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -322,6 +326,8 @@ class WorkflowAppGenerateTaskPipeline:
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
if node_success_response:
yield node_success_response
elif isinstance(
@ -339,6 +345,8 @@ class WorkflowAppGenerateTaskPipeline:
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_failed_response:
yield node_failed_response
@ -593,3 +601,15 @@ class WorkflowAppGenerateTaskPipeline:
)
return response
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

View File

@ -1,8 +1,6 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
@ -35,7 +33,6 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
BaseNodeEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
@ -70,9 +67,6 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
from services.workflow_draft_variable_service import (
DraftVariableSaver,
)
class WorkflowBasedAppRunner(AppRunner):
@ -400,7 +394,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
@ -464,7 +457,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
@ -718,30 +710,3 @@ class WorkflowBasedAppRunner(AppRunner):
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
def _save_draft_var_for_event(self, event: BaseNodeEvent):
run_result = event.route_node_state.node_run_result
if run_result is None:
return
process_data = run_result.process_data
outputs = run_result.outputs
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=self._get_app_id(),
node_id=event.node_id,
node_type=event.node_type,
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
invoke_from=self.queue_manager._invoke_from,
node_execution_id=event.id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
)
draft_var_saver.save(process_data=process_data, outputs=outputs)
def _remove_first_element_from_variable_string(key: str) -> str:
"""
Remove the first element from the prefix.
"""
prefix, remaining = key.split(".", maxsplit=1)
return remaining

View File

@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from models.enums import MessageStatus
from models.model import Message
logger = logging.getLogger(__name__)
@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
return err
err_desc = self._error_to_desc(err)
message.status = "error"
message.status = MessageStatus.ERROR
message.error = err_desc
return err

View File

@ -395,6 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:

View File

@ -51,7 +51,7 @@ class File(BaseModel):
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
mime_type: Optional[str] = None
size: int = -1

View File

@ -1,67 +0,0 @@
import base64
import logging
import time
from typing import Optional
from configs import dify_config
from constants import IMAGE_EXTENSIONS
from core.helper.url_signer import UrlSigner
from extensions.ext_storage import storage
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.exception(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file_id: the id of UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
result = UrlSigner.verify(
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
)
# verify signature
if not result:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View File

@ -28,7 +28,7 @@ class TemplateTransformer(ABC):
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
return result.group(1)
@classmethod
@ -38,16 +38,53 @@ class TemplateTransformer(ABC):
:param response: response
:return:
"""
try:
result = json.loads(cls.extract_result_str_from_response(response))
except json.JSONDecodeError:
raise ValueError("failed to parse response")
result_str = cls.extract_result_str_from_response(response)
result = json.loads(result_str)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
except ValueError as e:
# Re-raise ValueError from extract_result_str_from_response
raise e
except Exception as e:
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
# Check if the result contains an error
if isinstance(result, dict) and "error" in result:
raise ValueError(f"JavaScript execution error: {result['error']}")
if not isinstance(result, dict):
raise ValueError("result must be a dict")
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
if not all(isinstance(k, str) for k in result):
raise ValueError("result keys must be strings")
raise ValueError("Result keys must be strings")
# Post-process the result to convert scientific notation strings back to numbers
result = cls._post_process_result(result)
return result
@classmethod
def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
"""
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
try:
return float(value)
except ValueError:
pass
elif isinstance(value, dict):
return {k: convert_scientific_notation(v) for k, v in value.items()}
elif isinstance(value, list):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
@classmethod
@abstractmethod
def get_runner_script(cls) -> str:

View File

@ -1,22 +0,0 @@
from collections import OrderedDict
from typing import Any
class LRUCache:
def __init__(self, capacity: int):
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
else:
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
return self.cache[key]
def put(self, key: Any, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False) # pop the first item

View File

@ -317,9 +317,10 @@ class IndexingRunner:
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
if image_file:
storage.delete(image_file.key)
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \

View File

@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
@ -170,10 +171,15 @@ def invoke_llm_with_structured_output(
system_fingerprint: Optional[str] = None
for event in llm_result:
if isinstance(event, LLMResultChunk):
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
elif isinstance(event.delta.message.content, list):
for item in event.delta.message.content:
if isinstance(item, TextPromptMessageContent):
result_text += item.data
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,

View File

@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
latency=0.0,
)
@classmethod
def from_metadata(cls, metadata: dict) -> "LLMUsage":
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: Dictionary containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
total_tokens = metadata.get("total_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
)
def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.

View File

View File

@ -0,0 +1,487 @@
import json
import logging
from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
convert_datetime_to_nanoseconds,
convert_to_span_id,
convert_to_trace_id,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_SYSTEM,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes import NodeType
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
logger = logging.getLogger(__name__)
class AliyunDataTrace(BaseTraceInstance):
def __init__(
self,
aliyun_config: AliyunConfig,
):
super().__init__(aliyun_config)
base_url = aliyun_config.endpoint.rstrip("/")
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
pass
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self):
return self.trace_client.api_check()
def get_project_url(self):
try:
return self.trace_client.get_project_url()
except Exception as e:
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"Aliyun get run url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
self.trace_client.add_span(node_span)
def message_trace(self, trace_info: MessageTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
message_id = trace_info.message_id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(message_span)
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
pre_prompt = getattr(app_model_config, "pre_prompt", "")
inputs_data = getattr(trace_info.message_data, "inputs", {})
llm_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=convert_to_span_id(message_id, "llm"),
name="llm",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: str(trace_info.outputs),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(llm_span)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name="dataset_retrieval",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
)
self.trace_client.add_span(dataset_retrieval_span)
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
tool_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name=trace_info.tool_name,
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
self.trace_client.add_span(tool_span)
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
return workflow_node_executions
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
):
try:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
except Exception:
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
span_status: Status = Status(StatusCode.UNSET)
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
span_status = Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
span_status = Status(StatusCode.ERROR, str(node_execution.error))
return span_status
def build_workflow_task_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_tool_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_retrieval_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_llm_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=self.get_workflow_node_status(node_execution),
)
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
user_id = trace_info.metadata.get("user_id")
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id: # chatflow
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(message_span)
workflow_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(workflow_span)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
suggested_question_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=convert_to_span_id(message_id, "suggested_question"),
name="suggested_question",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(suggested_question_span)
def extract_retrieval_documents(documents: list[Document]):
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@ -0,0 +1,200 @@
import hashlib
import logging
import random
import socket
import threading
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Optional
import requests
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
logger = logging.getLogger(__name__)
class TraceClient:
def __init__(
self,
service_name: str,
endpoint: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
):
self.endpoint = endpoint
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
self.span_builder = SpanBuilder(self.resource)
self.exporter = OTLPSpanExporter(endpoint=endpoint)
self.max_queue_size = max_queue_size
self.schedule_delay_sec = schedule_delay_sec
self.max_export_batch_size = max_export_batch_size
self.queue: deque = deque(maxlen=max_queue_size)
self.condition = threading.Condition(threading.Lock())
self.done = False
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._spans_dropped = False
def export(self, spans: Sequence[ReadableSpan]):
self.exporter.export(spans)
def api_check(self):
try:
response = requests.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.debug(f"AliyunTrace API check failed: {str(e)}")
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self):
return "https://arms.console.aliyun.com/#/llm"
def add_span(self, span_data: SpanData):
if span_data is None:
return
span: ReadableSpan = self.span_builder.build_span(span_data)
with self.condition:
if len(self.queue) == self.max_queue_size:
if not self._spans_dropped:
logger.warning("Queue is full, likely spans will be dropped.")
self._spans_dropped = True
self.queue.appendleft(span)
if len(self.queue) >= self.max_export_batch_size:
self.condition.notify()
def _worker(self):
while not self.done:
with self.condition:
if len(self.queue) < self.max_export_batch_size and not self.done:
self.condition.wait(timeout=self.schedule_delay_sec)
self._export_batch()
def _export_batch(self):
spans_to_export: list[ReadableSpan] = []
with self.condition:
while len(spans_to_export) < self.max_export_batch_size and self.queue:
spans_to_export.append(self.queue.pop())
if spans_to_export:
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug(f"Error exporting spans: {e}")
def shutdown(self):
with self.condition:
self.done = True
self.condition.notify_all()
self.worker_thread.join()
self._export_batch()
self.exporter.shutdown()
class SpanBuilder:
def __init__(self, resource):
self.resource = resource
self.instrumentation_scope = InstrumentationScope(
__name__,
"",
None,
None,
)
def build_span(self, span_data: SpanData) -> ReadableSpan:
span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
parent_span_context = None
if span_data.parent_span_id is not None:
parent_span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.parent_span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
span = ReadableSpan(
name=span_data.name,
context=span_context,
parent=parent_span_context,
resource=self.resource,
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
instrumentation_scope=self.instrumentation_scope,
)
return span
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
return span_id
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
if start_time_a is None:
return None
timestamp_in_seconds = start_time_a.timestamp()
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
return timestamp_in_nanoseconds

View File

@ -0,0 +1,21 @@
from collections.abc import Sequence
from typing import Optional
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")

View File

@ -0,0 +1,64 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
# Chain
INPUT_VALUE = "input.value"
OUTPUT_VALUE = "output.value"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# LLM
GEN_AI_MODEL_NAME = "gen_ai.model_name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
class GenAISpanKind(Enum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"
LLM = "LLM"
EMBEDDING = "EMBEDDING"
TOOL = "TOOL"
AGENT = "AGENT"
TASK = "TASK"

View File

@ -0,0 +1,726 @@
import hashlib
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
try:
# Choose the appropriate exporter based on config type
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
if isinstance(arize_phoenix_config, ArizeConfig):
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
arize_headers = {
"api_key": arize_phoenix_config.api_key or "",
"space_id": arize_phoenix_config.space_id or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = GrpcOTLPSpanExporter(
endpoint=arize_endpoint,
headers=arize_headers,
timeout=30,
)
else:
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
phoenix_headers = {
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=30,
)
attributes = {
"openinference.project.name": arize_phoenix_config.project or "",
"model_id": arize_phoenix_config.project or "",
}
resource = Resource(attributes=attributes)
provider = trace_sdk.TracerProvider(resource=resource)
processor = SimpleSpanProcessor(
exporter,
)
provider.add_span_processor(processor)
# Create a named tracer instead of setting the global provider
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
except Exception as e:
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
raise
def datetime_to_nanos(dt: Optional[datetime]) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def uuid_to_trace_id(string: Optional[str]) -> int:
"""Convert UUID string to a valid trace ID (16-byte integer)."""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
# Take the first 16 bytes (128 bits) of the hash
digest = hash_object.digest()[:16]
# Convert to integer (128 bits)
return int.from_bytes(digest, byteorder="big")
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
if trace_info.message_data is None:
return
workflow_metadata = {
"workflow_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_metadata["ls_provider"] = provider
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
)
try:
if node_execution.node_type == "llm":
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
if model:
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
return
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
trace_id = uuid_to_trace_id(trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
if trace_info.message_data.model_id is not None:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
if trace_info.message_data.model_provider is not None:
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
trace_id = uuid_to_trace_id(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
span.set_attribute("test", "true")
return True
except Exception as e:
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
except Exception as e:
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes

View File

@ -2,20 +2,92 @@ from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum):
ARIZE = "arize"
PHOENIX = "phoenix"
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
class BaseTracingConfig(BaseModel):
"""
Base model class for tracing
Base model class for tracing configurations
"""
...
@classmethod
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
"""
Common endpoint URL validation logic
Args:
v: URL value to validate
default_url: Default URL to use if input is None or empty
Returns:
Validated and normalized URL
"""
return validate_url(v, default_url)
@classmethod
def validate_project_field(cls, v: str, default_name: str) -> str:
"""
Common project name validation logic
Args:
v: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Validated project name
"""
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
@field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.langfuse.com"
if not v.startswith("https://") and not v.startswith("http://"):
raise ValueError("host must start with https:// or http://")
return v
def host_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.smith.langchain.com"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "Default Project"
return v
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://www.comet.com/opik/api/"
if not v.startswith(("https://", "http://")):
raise ValueError("url must start with https:// or http://")
if not v.endswith("/api/"):
raise ValueError("url should ends with /api/")
return v
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://trace.wandb.ai"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def validate_host(cls, v, info: ValidationInfo):
if v is not None and v != "":
if not v.startswith(("https://", "http://")):
raise ValueError("host must start with https:// or http://")
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -32,6 +32,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@ -180,12 +181,9 @@ class LangFuseDataTrace(BaseTraceInstance):
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@ -293,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
input=trace_info.inputs,
output=message_data.answer,
metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
@ -339,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)

View File

@ -206,12 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

View File

@ -222,10 +222,10 @@ class OpikDataTrace(BaseTraceInstance):
)
try:
if outputs.get("usage"):
total_tokens = outputs["usage"].get("total_tokens", 0)
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
completion_tokens = outputs["usage"].get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
total_tokens = usage_data.get("total_tokens", 0)
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

View File

@ -84,6 +84,36 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")

View File

@ -1,6 +1,7 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
from extensions.ext_database import db
from models.model import Message
@ -60,3 +61,83 @@ def generate_dotted_order(
return current_segment
return f"{parent_dotted_order}.{current_segment}"
def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str:
"""
Validate and normalize URL with proper error handling
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
allowed_schemes: Tuple of allowed URL schemes (default: https, http)
Returns:
Normalized URL string
Raises:
ValueError: If URL format is invalid or scheme not allowed
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in allowed_schemes:
raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}")
# Reconstruct URL with only scheme, netloc (removing path, query, fragment)
normalized_url = f"{parsed.scheme}://{parsed.netloc}"
return normalized_url
def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str:
"""
Validate URL that may include path components
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
required_suffix: Optional suffix that URL must end with
Returns:
Validated URL string
Raises:
ValueError: If URL format is invalid or doesn't match required suffix
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in ("https", "http"):
raise ValueError("URL must start with https:// or http://")
# Check required suffix if specified
if required_suffix and not url.endswith(required_suffix):
raise ValueError(f"URL should end with {required_suffix}")
return url
def validate_project_name(project: str, default_name: str) -> str:
"""
Validate and normalize project name
Args:
project: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Normalized project name
"""
if not project or project.strip() == "":
return default_name
return project.strip()

View File

@ -1,7 +1,6 @@
"""Document loader helpers."""
import concurrent.futures
from pathlib import Path
from typing import NamedTuple, Optional, cast
@ -16,7 +15,7 @@ class FileEncoding(NamedTuple):
"""The language of the file."""
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1024 * 1024) -> list[FileEncoding]:
"""Try to detect the file encoding.
Returns a list of `FileEncoding` tuples with the detected encodings ordered
@ -25,11 +24,16 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
Args:
file_path: The path to the file to detect the encoding for.
timeout: The timeout in seconds for the encoding detection.
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
For large files, reading only a sample is sufficient and prevents timeout.
"""
import chardet
def read_and_detect(file_path: str) -> list[dict]:
rawdata = Path(file_path).read_bytes()
with open(file_path, "rb") as f:
# Read only a sample of the file for encoding detection
# This prevents timeout on large files while still providing accurate encoding detection
rawdata = f.read(sample_size)
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:

View File

@ -36,8 +36,12 @@ class TextExtractor(BaseExtractor):
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(
f"Decode failed: {self._file_path}, all detected encodings failed. Original error: {e}"
)
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
raise RuntimeError(f"Decode failed: {self._file_path}, specified encoding failed. Original error: {e}")
except Exception as e:
raise RuntimeError(f"Error loading {self._file_path}") from e

View File

@ -1010,6 +1010,9 @@ class DatasetRetrieval:
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
if value is None:
return
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:

View File

@ -31,6 +31,14 @@ class TTSTool(BuiltinTool):
model_type=ModelType.TTS,
model=model,
)
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,

View File

@ -66,11 +66,21 @@ class WorkflowNodeExecution(BaseModel):
but they are not stored in the model.
"""
# Core identification fields
id: str # Unique identifier for this execution record
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
# --------- Core identification fields ---------
# Unique identifier for this execution record, used when persisting to storage.
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
id: str
# Optional secondary ID for cross-referencing purposes.
#
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
# In most scenarios, `id` should be used as the primary identifier.
node_execution_id: Optional[str] = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
# --------- Core identification fields ends ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization

View File

@ -103,7 +103,7 @@ class GraphEngine:
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
@ -140,7 +140,7 @@ class GraphEngine:
call_depth=call_depth,
)
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time

View File

@ -1,4 +1,5 @@
import json
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@ -15,7 +16,7 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
@ -106,6 +107,32 @@ class AgentNode(ToolNode):
try:
# convert tool messages
agent_thoughts: list = []
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message(
message_stream,
@ -114,6 +141,7 @@ class AgentNode(ToolNode):
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
parameters_for_log,
agent_thoughts,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(

View File

@ -2,7 +2,6 @@ import logging
from collections.abc import Generator
from typing import cast
from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -201,44 +200,3 @@ class AnswerStreamProcessor(StreamProcessor):
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
@classmethod
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file_var = cls._get_file_var_from_value(item)
if file_var:
files.append(file_var)
elif isinstance(value, dict):
file_var = cls._get_file_var_from_value(value)
if file_var:
files.append(file_var)
return files
@classmethod
def _get_file_var_from_value(cls, value: dict | list):
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict):
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None

View File

@ -8,6 +8,7 @@ from typing import Any, Literal
from urllib.parse import urlencode, urlparse
import httpx
from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
@ -178,7 +179,8 @@ class Executor:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
try:
json_object = json.loads(json_string, strict=False)
repaired = repair_json(json_string)
json_object = json.loads(repaired, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
@ -333,7 +335,7 @@ class Executor:
try:
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore

View File

@ -1,5 +1,6 @@
import contextvars
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
@ -133,8 +134,11 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -146,7 +150,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,

View File

@ -490,6 +490,9 @@ class KnowledgeRetrievalNode(LLMNode):
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
if value is None:
return
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:

View File

@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]):
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# handle invoke result
generator = self._invoke_llm(
node_data_model=self.node_data.model,
@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]):
elif isinstance(event, LLMStructuredOutput):
structured_output = event
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
if structured_output:
outputs["structured_output"] = structured_output.structured_output

View File

@ -1,5 +1,6 @@
import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast
@ -101,8 +102,11 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -114,7 +118,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,

View File

@ -253,7 +253,12 @@ class ParameterExtractorNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
outputs={
"__is_success": 1 if not error else 0,
"__reason": error,
"__usage": jsonable_encoder(usage),
**result,
},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,

View File

@ -145,7 +145,11 @@ class QuestionClassifierNode(LLMNode):
"model_provider": model_config.provider,
"model_name": model_config.model,
}
outputs = {"class_name": category_name, "class_id": category_id}
outputs = {
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -1,11 +1,12 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -190,6 +191,7 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@ -208,7 +210,7 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
@ -276,9 +278,10 @@ class ToolNode(BaseNode[ToolNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
key: value
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
@ -366,17 +369,42 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: dict[str, Any] = {}
if json:
if isinstance(json, list) and len(json) == 1:
# If json is a list with only one element, convert it to a dictionary
json_output = json[0] if isinstance(json[0], dict) else {"data": json[0]}
elif isinstance(json, list):
# If json is a list with multiple elements, create a dictionary containing all data
json_output = {"data": json}
if agent_logs:
# Add agent_logs to json output
json_output["agent_logs"] = [
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
for log in agent_logs
]
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@ -0,0 +1,32 @@
import abc
from collections.abc import Mapping
from typing import Any, Protocol
from sqlalchemy.orm import Session
from core.workflow.nodes.enums import NodeType
class DraftVariableSaver(Protocol):
@abc.abstractmethod
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
pass
class DraftVariableSaverFactory(Protocol):
@abc.abstractmethod
def __call__(
self,
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> "DraftVariableSaver":
pass
class NoopDraftVariableSaver(DraftVariableSaver):
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
pass

View File

@ -27,6 +27,7 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@dataclass
@ -160,12 +161,13 @@ class WorkflowCycleManager:
exceptions_count: int = 0,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
workflow_execution.status = WorkflowExecutionStatus(status.value)
workflow_execution.error_message = error_message
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_execution.finished_at = now
workflow_execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
@ -174,7 +176,6 @@ class WorkflowCycleManager:
)
# Update the domain models
now = datetime.now(UTC).replace(tzinfo=None)
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model

View File

@ -69,6 +69,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
@ -80,7 +81,7 @@ class WorkflowEntry:
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id,