mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
refactor(api): move trace providers (#35144)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,547 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry.trace import SpanKind
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
DIFY_APP_ID,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROVIDER_NAME,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
extract_retrieval_documents,
|
||||
format_input_messages,
|
||||
format_output_messages,
|
||||
format_retrieval_documents,
|
||||
get_user_id_from_message_data,
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AliyunDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
aliyun_config: AliyunConfig,
|
||||
):
|
||||
super().__init__(aliyun_config)
|
||||
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self):
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.trace_client.get_project_url()
|
||||
except Exception as e:
|
||||
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Aliyun get project url failed: {str(e)}")
|
||||
|
||||
def _extract_app_id(self, trace_info: BaseTraceInfo) -> str:
|
||||
"""Extract app_id from trace_info, trying metadata first then message_data."""
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if app_id:
|
||||
return str(app_id)
|
||||
message_data = getattr(trace_info, "message_data", None)
|
||||
if message_data is not None:
|
||||
return str(getattr(message_data, "app_id", ""))
|
||||
return ""
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
|
||||
workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"),
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
self.add_workflow_span(trace_info, trace_metadata)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata)
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
user_id = get_user_id_from_message_data(message_data)
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=user_id,
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
DIFY_APP_ID: self._extract_app_id(trace_info),
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
llm_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=convert_to_span_id(message_id, "llm"),
|
||||
name="llm",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
|
||||
GEN_AI_PROMPT: inputs_json,
|
||||
GEN_AI_COMPLETION: outputs_str,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
documents_json = serialize_json_data(documents_data)
|
||||
inputs_str = str(trace_info.inputs)
|
||||
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.RETRIEVER,
|
||||
inputs=inputs_str,
|
||||
outputs=documents_json,
|
||||
),
|
||||
RETRIEVAL_QUERY: inputs_str,
|
||||
RETRIEVAL_DOCUMENT: documents_json,
|
||||
},
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
message_id = trace_info.message_id
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
tool_config_json = serialize_json_data(trace_info.tool_config)
|
||||
tool_inputs_json = serialize_json_data(trace_info.tool_inputs)
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TOOL,
|
||||
inputs=inputs_json,
|
||||
outputs=str(trace_info.tool_outputs),
|
||||
),
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: tool_config_json,
|
||||
TOOL_PARAMETERS: tool_inputs_json,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
def build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
return node_span
|
||||
except Exception as e:
|
||||
logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def build_workflow_task_span(
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
inputs_json = serialize_json_data(node_execution.inputs)
|
||||
outputs_json = serialize_json_data(node_execution.outputs)
|
||||
return SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TASK,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_tool_span(
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
|
||||
inputs_json = serialize_json_data(node_execution.inputs or {})
|
||||
outputs_json = serialize_json_data(node_execution.outputs)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.TOOL,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: serialize_json_data(tool_des),
|
||||
TOOL_PARAMETERS: inputs_json,
|
||||
},
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_retrieval_span(
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else ""
|
||||
output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else ""
|
||||
|
||||
retrieval_documents = node_execution.outputs.get("result", []) if node_execution.outputs else []
|
||||
semantic_retrieval_documents = format_retrieval_documents(retrieval_documents)
|
||||
semantic_retrieval_documents_json = serialize_json_data(semantic_retrieval_documents)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.RETRIEVER,
|
||||
inputs=input_value,
|
||||
outputs=output_value,
|
||||
),
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: semantic_retrieval_documents_json,
|
||||
},
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def build_workflow_llm_span(
|
||||
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
|
||||
) -> SpanData:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
|
||||
prompts_json = serialize_json_data(process_data.get("prompts", []))
|
||||
text_output = str(outputs.get("text", ""))
|
||||
|
||||
gen_ai_input_message = format_input_messages(process_data)
|
||||
gen_ai_output_message = format_output_messages(outputs)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=trace_metadata.workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=prompts_json,
|
||||
outputs=text_output,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: process_data.get("model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: process_data.get("model_provider") or "",
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: prompts_json,
|
||||
GEN_AI_COMPLETION: text_output,
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
|
||||
GEN_AI_INPUT_MESSAGE: gen_ai_input_message,
|
||||
GEN_AI_OUTPUT_MESSAGE: gen_ai_output_message,
|
||||
},
|
||||
status=get_workflow_node_status(node_execution),
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
|
||||
def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
|
||||
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
|
||||
|
||||
app_id = self._extract_app_id(trace_info)
|
||||
|
||||
if message_span_id:
|
||||
message_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
DIFY_APP_ID: app_id,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
workflow_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=trace_metadata.workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
**({DIFY_APP_ID: app_id} if message_span_id is None else {}),
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_id = trace_info.message_id
|
||||
status = create_status_from_error(trace_info.error)
|
||||
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
workflow_span_id=0,
|
||||
session_id=trace_info.metadata.get("conversation_id") or "",
|
||||
user_id=str(trace_info.metadata.get("user_id") or ""),
|
||||
links=create_links_from_trace_id(trace_info.trace_id),
|
||||
)
|
||||
|
||||
inputs_json = serialize_json_data(trace_info.inputs)
|
||||
suggested_question_json = serialize_json_data(trace_info.suggested_question)
|
||||
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.LLM,
|
||||
inputs=inputs_json,
|
||||
outputs=suggested_question_json,
|
||||
),
|
||||
GEN_AI_REQUEST_MODEL: trace_info.metadata.get("ls_model_name") or "",
|
||||
GEN_AI_PROVIDER_NAME: trace_info.metadata.get("ls_provider") or "",
|
||||
GEN_AI_PROMPT: inputs_json,
|
||||
GEN_AI_COMPLETION: suggested_question_json,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
@ -1,244 +0,0 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Final
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
DEFAULT_TIMEOUT: Final[int] = 5
|
||||
DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000
|
||||
DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5
|
||||
DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TraceClient:
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
|
||||
schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC,
|
||||
max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
self.exporter = OTLPSpanExporter(endpoint=endpoint)
|
||||
|
||||
self.max_queue_size = max_queue_size
|
||||
self.schedule_delay_sec = schedule_delay_sec
|
||||
self.max_export_batch_size = max_export_batch_size
|
||||
|
||||
self.queue: deque = deque(maxlen=max_queue_size)
|
||||
self.condition = threading.Condition(threading.Lock())
|
||||
self.done = False
|
||||
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
self._spans_dropped = False
|
||||
|
||||
def export(self, spans: Sequence[ReadableSpan]):
|
||||
self.exporter.export(spans)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
try:
|
||||
response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
logger.warning("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
return "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
def add_span(self, span_data: SpanData | None) -> None:
|
||||
if span_data is None:
|
||||
return
|
||||
|
||||
span: ReadableSpan = self.span_builder.build_span(span_data)
|
||||
with self.condition:
|
||||
if len(self.queue) == self.max_queue_size:
|
||||
if not self._spans_dropped:
|
||||
logger.warning("Queue is full, likely spans will be dropped.")
|
||||
self._spans_dropped = True
|
||||
|
||||
self.queue.appendleft(span)
|
||||
if len(self.queue) >= self.max_export_batch_size:
|
||||
self.condition.notify()
|
||||
|
||||
def _worker(self) -> None:
|
||||
while not self.done:
|
||||
with self.condition:
|
||||
if len(self.queue) < self.max_export_batch_size and not self.done:
|
||||
self.condition.wait(timeout=self.schedule_delay_sec)
|
||||
self._export_batch()
|
||||
|
||||
def _export_batch(self) -> None:
|
||||
spans_to_export: list[ReadableSpan] = []
|
||||
with self.condition:
|
||||
while len(spans_to_export) < self.max_export_batch_size and self.queue:
|
||||
spans_to_export.append(self.queue.pop())
|
||||
|
||||
if spans_to_export:
|
||||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.warning("Error exporting spans: %s", e)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
with self.condition:
|
||||
self.done = True
|
||||
self.condition.notify_all()
|
||||
self.worker_thread.join()
|
||||
self._export_batch()
|
||||
self.exporter.shutdown()
|
||||
|
||||
|
||||
class SpanBuilder:
|
||||
def __init__(self, resource: Resource) -> None:
|
||||
self.resource = resource
|
||||
self.instrumentation_scope = InstrumentationScope(
|
||||
__name__,
|
||||
"",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
def build_span(self, span_data: SpanData) -> ReadableSpan:
|
||||
span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
parent_span_context = None
|
||||
if span_data.parent_span_id is not None:
|
||||
parent_span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.parent_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
span = ReadableSpan(
|
||||
name=span_data.name,
|
||||
context=span_context,
|
||||
parent=parent_span_context,
|
||||
resource=self.resource,
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=span_data.span_kind,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
instrumentation_scope=self.instrumentation_scope,
|
||||
)
|
||||
return span
|
||||
|
||||
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
placeholder_span_id = INVALID_SPAN_ID
|
||||
try:
|
||||
trace_id = int(trace_id_str, 16)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e
|
||||
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
|
||||
)
|
||||
|
||||
return Link(span_context)
|
||||
|
||||
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
if uuid_v4 is None:
|
||||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
|
||||
|
||||
def convert_string_to_id(string: str | None) -> int:
|
||||
if not string:
|
||||
return generate_span_id()
|
||||
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
|
||||
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
if uuid_v4 is None:
|
||||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
return convert_string_to_id(combined_key)
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None:
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
return int(timestamp_in_seconds * 1e9)
|
||||
|
||||
|
||||
def build_endpoint(base_url: str, license_key: str) -> str:
|
||||
if "log.aliyuncs.com" in base_url: # cms2.0 endpoint
|
||||
return urljoin(base_url, f"adapt_{license_key}/api/v1/traces")
|
||||
else: # xtrace endpoint
|
||||
return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces")
|
||||
@ -1,37 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceMetadata:
|
||||
"""Metadata for trace operations, containing common attributes for all spans in a trace."""
|
||||
|
||||
trace_id: int
|
||||
workflow_span_id: int
|
||||
session_id: str
|
||||
user_id: str
|
||||
links: list[trace_api.Link]
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
"""Data model for span information in Aliyun trace system."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
|
||||
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
|
||||
@ -1,51 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
|
||||
|
||||
# Dify-specific attributes
|
||||
DIFY_APP_ID: Final[str] = "dify.app_id"
|
||||
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name"
|
||||
GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind"
|
||||
GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework"
|
||||
|
||||
# Chain attributes
|
||||
INPUT_VALUE: Final[str] = "input.value"
|
||||
OUTPUT_VALUE: Final[str] = "output.value"
|
||||
|
||||
# Retriever attributes
|
||||
RETRIEVAL_QUERY: Final[str] = "retrieval.query"
|
||||
RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document"
|
||||
|
||||
# LLM attributes
|
||||
GEN_AI_REQUEST_MODEL: Final[str] = "gen_ai.request.model"
|
||||
GEN_AI_PROVIDER_NAME: Final[str] = "gen_ai.provider.name"
|
||||
GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens"
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens"
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens"
|
||||
GEN_AI_PROMPT: Final[str] = "gen_ai.prompt"
|
||||
GEN_AI_COMPLETION: Final[str] = "gen_ai.completion"
|
||||
GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason"
|
||||
|
||||
GEN_AI_INPUT_MESSAGE: Final[str] = "gen_ai.input.messages"
|
||||
GEN_AI_OUTPUT_MESSAGE: Final[str] = "gen_ai.output.messages"
|
||||
|
||||
# Tool attributes
|
||||
TOOL_NAME: Final[str] = "tool.name"
|
||||
TOOL_DESCRIPTION: Final[str] = "tool.description"
|
||||
TOOL_PARAMETERS: Final[str] = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(StrEnum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
LLM = "LLM"
|
||||
EMBEDDING = "EMBEDDING"
|
||||
TOOL = "TOOL"
|
||||
AGENT = "AGENT"
|
||||
TASK = "TASK"
|
||||
@ -1,200 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models import EndUser
|
||||
|
||||
# Constants
|
||||
DEFAULT_JSON_ENSURE_ASCII = False
|
||||
DEFAULT_FRAMEWORK_NAME = "dify"
|
||||
|
||||
|
||||
def get_user_id_from_message_data(message_data) -> str:
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
return user_id
|
||||
|
||||
|
||||
def create_status_from_error(error: str | None) -> Status:
|
||||
if error:
|
||||
return Status(StatusCode.ERROR, error)
|
||||
return Status(StatusCode.OK)
|
||||
|
||||
|
||||
def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return Status(StatusCode.OK)
|
||||
if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
return Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return Status(StatusCode.UNSET)
|
||||
|
||||
|
||||
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
|
||||
|
||||
links = []
|
||||
if trace_id:
|
||||
links.append(create_link(trace_id_str=trace_id))
|
||||
return links
|
||||
|
||||
|
||||
class RetrievalDocumentMetadataDict(TypedDict):
|
||||
dataset_id: Any
|
||||
doc_id: Any
|
||||
document_id: Any
|
||||
|
||||
|
||||
class RetrievalDocumentDict(TypedDict):
|
||||
content: str
|
||||
metadata: RetrievalDocumentMetadataDict
|
||||
score: Any
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
|
||||
documents_data: list[RetrievalDocumentDict] = []
|
||||
for document in documents:
|
||||
document_data: RetrievalDocumentDict = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
|
||||
|
||||
def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
|
||||
return json.dumps(data, ensure_ascii=ensure_ascii)
|
||||
|
||||
|
||||
def create_common_span_attributes(
|
||||
session_id: str = "",
|
||||
user_id: str = "",
|
||||
span_kind: str = GenAISpanKind.CHAIN,
|
||||
framework: str = DEFAULT_FRAMEWORK_NAME,
|
||||
inputs: str = "",
|
||||
outputs: str = "",
|
||||
) -> dict[str, str]:
|
||||
return {
|
||||
GEN_AI_SESSION_ID: session_id,
|
||||
GEN_AI_USER_ID: user_id,
|
||||
GEN_AI_SPAN_KIND: span_kind,
|
||||
GEN_AI_FRAMEWORK: framework,
|
||||
INPUT_VALUE: inputs,
|
||||
OUTPUT_VALUE: outputs,
|
||||
}
|
||||
|
||||
|
||||
def format_retrieval_documents(retrieval_documents: list) -> list:
|
||||
try:
|
||||
if not isinstance(retrieval_documents, list):
|
||||
return []
|
||||
|
||||
semantic_documents = []
|
||||
for doc in retrieval_documents:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
|
||||
metadata = doc.get("metadata", {})
|
||||
content = doc.get("content", "")
|
||||
title = doc.get("title", "")
|
||||
score = metadata.get("score", 0.0)
|
||||
document_id = metadata.get("document_id", "")
|
||||
|
||||
semantic_metadata = {}
|
||||
if title:
|
||||
semantic_metadata["title"] = title
|
||||
if metadata.get("source"):
|
||||
semantic_metadata["source"] = metadata["source"]
|
||||
elif metadata.get("_source"):
|
||||
semantic_metadata["source"] = metadata["_source"]
|
||||
if metadata.get("doc_metadata"):
|
||||
doc_metadata = metadata["doc_metadata"]
|
||||
if isinstance(doc_metadata, dict):
|
||||
semantic_metadata.update(doc_metadata)
|
||||
|
||||
semantic_doc = {
|
||||
"document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
|
||||
}
|
||||
semantic_documents.append(semantic_doc)
|
||||
|
||||
return semantic_documents
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def format_input_messages(process_data: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
if not isinstance(process_data, dict):
|
||||
return serialize_json_data([])
|
||||
|
||||
prompts = process_data.get("prompts", [])
|
||||
if not prompts:
|
||||
return serialize_json_data([])
|
||||
|
||||
valid_roles = {"system", "user", "assistant", "tool"}
|
||||
input_messages = []
|
||||
for prompt in prompts:
|
||||
if not isinstance(prompt, dict):
|
||||
continue
|
||||
|
||||
role = prompt.get("role", "")
|
||||
text = prompt.get("text", "")
|
||||
|
||||
if not role or role not in valid_roles:
|
||||
continue
|
||||
|
||||
if text:
|
||||
message = {"role": role, "parts": [{"type": "text", "content": text}]}
|
||||
input_messages.append(message)
|
||||
|
||||
return serialize_json_data(input_messages)
|
||||
except Exception:
|
||||
return serialize_json_data([])
|
||||
|
||||
|
||||
def format_output_messages(outputs: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
if not isinstance(outputs, dict):
|
||||
return serialize_json_data([])
|
||||
|
||||
text = outputs.get("text", "")
|
||||
finish_reason = outputs.get("finish_reason", "")
|
||||
|
||||
if not text:
|
||||
return serialize_json_data([])
|
||||
|
||||
valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
|
||||
if finish_reason not in valid_finish_reasons:
|
||||
finish_reason = "stop"
|
||||
|
||||
output_message = {
|
||||
"role": "assistant",
|
||||
"parts": [{"type": "text", "content": text}],
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
return serialize_json_data([output_message])
|
||||
except Exception:
|
||||
return serialize_json_data([])
|
||||
@ -1,848 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import (
|
||||
MessageAttributes,
|
||||
OpenInferenceMimeTypeValues,
|
||||
OpenInferenceSpanKindValues,
|
||||
SpanAttributes,
|
||||
ToolCallAttributes,
|
||||
)
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.attributes import exception_attributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
|
||||
# Inspect the provided endpoint to determine its structure
|
||||
parsed = urlparse(arize_phoenix_config.endpoint)
|
||||
base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{base_endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = GrpcOTLPSpanExporter(
|
||||
endpoint=arize_endpoint,
|
||||
headers=arize_headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=phoenix_endpoint,
|
||||
headers=phoenix_headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
attributes = {
|
||||
"openinference.project.name": arize_phoenix_config.project or "",
|
||||
"model_id": arize_phoenix_config.project or "",
|
||||
}
|
||||
resource = Resource(attributes=attributes)
|
||||
provider = trace_sdk.TracerProvider(resource=resource)
|
||||
processor = SimpleSpanProcessor(
|
||||
exporter,
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# Create a named tracer instead of setting the global provider
|
||||
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
|
||||
logger.info("[Arize/Phoenix] Created tracer with name: %s", tracer_name)
|
||||
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
|
||||
except Exception as e:
|
||||
logger.error("[Arize/Phoenix] Failed to setup the tracer: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: datetime | None) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def error_to_string(error: Exception | str | None) -> str:
|
||||
"""Convert an error to a string with traceback information for Arize/Phoenix."""
|
||||
error_message = "Empty Stack Trace"
|
||||
if error:
|
||||
if isinstance(error, Exception):
|
||||
string_stacktrace = "".join(traceback.format_exception(error))
|
||||
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
|
||||
else:
|
||||
error_message = str(error)
|
||||
return error_message
|
||||
|
||||
|
||||
def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
"""Set the status of the current span based on the presence of an error for Arize/Phoenix."""
|
||||
if error:
|
||||
error_string = error_to_string(error)
|
||||
current_span.set_status(Status(StatusCode.ERROR, error_string))
|
||||
|
||||
if isinstance(error, Exception):
|
||||
current_span.record_exception(error)
|
||||
else:
|
||||
exception_type = error.__class__.__name__
|
||||
exception_message = str(error)
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
exception_attributes.EXCEPTION_TYPE: exception_type,
|
||||
exception_attributes.EXCEPTION_MESSAGE: exception_message,
|
||||
exception_attributes.EXCEPTION_ESCAPED: False,
|
||||
exception_attributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
current_span.set_status(Status(StatusCode.OK))
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any) -> str:
|
||||
"""A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
|
||||
return json.dumps(obj, default=str, ensure_ascii=False)
|
||||
|
||||
|
||||
def wrap_span_metadata(metadata, **kwargs):
|
||||
"""Add common metatada to all trace entity types for Arize/Phoenix."""
|
||||
metadata["created_from"] = "Dify"
|
||||
metadata.update(kwargs)
|
||||
return metadata
|
||||
|
||||
|
||||
# Mapping from built-in node type strings to OpenInference span kinds.
|
||||
# Node types not listed here default to CHAIN.
|
||||
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
"llm": OpenInferenceSpanKindValues.LLM,
|
||||
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
|
||||
"tool": OpenInferenceSpanKindValues.TOOL,
|
||||
"agent": OpenInferenceSpanKindValues.AGENT,
|
||||
}
|
||||
|
||||
|
||||
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
"""Return the OpenInference span kind for a given workflow node type.
|
||||
|
||||
Covers every built-in node type string. Nodes that do not have a
|
||||
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
|
||||
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
|
||||
"""
|
||||
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.propagator = TraceContextTextMapPropagator()
|
||||
self.dify_trace_ids: set[str] = set()
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
|
||||
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.workflow_run_status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="workflow",
|
||||
conversation_id=trace_info.conversation_id or "",
|
||||
workflow_app_log_id=trace_info.workflow_app_log_id or "",
|
||||
workflow_id=trace_info.workflow_id or "",
|
||||
tenant_id=trace_info.tenant_id or "",
|
||||
workflow_run_id=trace_info.workflow_run_id or "",
|
||||
workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
|
||||
workflow_run_version=trace_info.workflow_run_version or "",
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
file_list=safe_json_dumps(file_list),
|
||||
query=trace_info.query or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
# Through workflow_run_id, get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
try:
|
||||
for node_execution in workflow_node_executions:
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
inputs_value = node_execution.inputs or {}
|
||||
outputs_value = node_execution.outputs or {}
|
||||
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
|
||||
node_metadata.update(
|
||||
{
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"status_message": node_execution.error or "",
|
||||
"level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT",
|
||||
}
|
||||
)
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = _get_node_span_kind(node_execution.node_type)
|
||||
if node_execution.node_type == "llm":
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_metadata["ls_provider"] = provider
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
context=workflow_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
llm_attributes: dict[str, Any] = {
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
}
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
|
||||
if model:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
|
||||
node_span.set_attributes(llm_attributes)
|
||||
finally:
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
set_span_status(node_span, node_execution.error)
|
||||
else:
|
||||
set_span_status(node_span)
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(workflow_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(workflow_span)
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
|
||||
return
|
||||
|
||||
file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="message",
|
||||
conversation_model=trace_info.conversation_model or "",
|
||||
message_tokens=trace_info.message_tokens or 0,
|
||||
answer_tokens=trace_info.answer_tokens or 0,
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
conversation_mode=trace_info.conversation_mode or "",
|
||||
gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
|
||||
llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
|
||||
is_streaming_request=trace_info.is_streaming_request or False,
|
||||
user_id=trace_info.message_data.from_account_id or "",
|
||||
file_list=safe_json_dumps(file_list),
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
|
||||
}
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert outputs to string based on type
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = safe_json_dumps(trace_info.outputs)
|
||||
outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
|
||||
}
|
||||
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
|
||||
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
|
||||
|
||||
if trace_info.message_data.model_id is not None:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
|
||||
if trace_info.message_data.model_provider is not None:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
message_span_context = set_span_in_context(message_span)
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=message_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(llm_span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(llm_span)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
if trace_info.error:
|
||||
set_span_status(message_span, trace_info.error)
|
||||
else:
|
||||
set_span_status(message_span)
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
|
||||
return
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.message_data.error or "",
|
||||
level="ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
trace_entity_type="moderation",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
|
||||
{
|
||||
"flagged": trace_info.flagged,
|
||||
"action": trace_info.action,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"query": trace_info.query,
|
||||
}
|
||||
),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.status or "",
|
||||
status_message=trace_info.status_message or "",
|
||||
level=trace_info.level or "",
|
||||
trace_entity_type="suggested_question",
|
||||
total_tokens=trace_info.total_tokens or 0,
|
||||
from_account_id=trace_info.from_account_id or "",
|
||||
agent_based=trace_info.agent_based or False,
|
||||
from_source=trace_info.from_source or "",
|
||||
model_provider=trace_info.model_provider or "",
|
||||
model_id=trace_info.model_id or "",
|
||||
workflow_run_id=trace_info.workflow_run_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="dataset_retrieval",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.error or "",
|
||||
level="ERROR" if trace_info.error else "DEFAULT",
|
||||
trace_entity_type="tool",
|
||||
tool_config=safe_json_dumps(trace_info.tool_config),
|
||||
time_cost=trace_info.time_cost or 0,
|
||||
file_url=trace_info.file_url or "",
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
set_span_status(span, trace_info.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
|
||||
return
|
||||
|
||||
metadata = wrap_span_metadata(
|
||||
trace_info.metadata,
|
||||
trace_id=trace_info.trace_id or "",
|
||||
message_id=trace_info.message_id or "",
|
||||
status=trace_info.message_data.status or "",
|
||||
status_message=trace_info.message_data.error or "",
|
||||
level="ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
trace_entity_type="generate_name",
|
||||
model_provider=trace_info.message_data.model_provider or "",
|
||||
model_id=trace_info.message_data.model_id or "",
|
||||
conversation_id=trace_info.conversation_id or "",
|
||||
tenant_id=trace_info.tenant_id,
|
||||
)
|
||||
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
|
||||
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
|
||||
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
|
||||
SpanAttributes.METADATA: safe_json_dumps(metadata),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
set_span_status(span, trace_info.message_data.error)
|
||||
else:
|
||||
set_span_status(span)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def ensure_root_span(self, dify_trace_id: str | None):
|
||||
"""Ensure a unique root span exists for the given Dify trace ID."""
|
||||
if str(dify_trace_id) not in self.dify_trace_ids:
|
||||
self.carrier: dict[str, str] = {}
|
||||
|
||||
root_span = self.tracer.start_span(name="Dify")
|
||||
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
|
||||
root_span.set_attribute("dify_project_name", str(self.project))
|
||||
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
|
||||
|
||||
with use_span(root_span, end_on_exit=False):
|
||||
self.propagator.inject(carrier=self.carrier)
|
||||
|
||||
set_span_status(root_span)
|
||||
root_span.end()
|
||||
self.dify_trace_ids.add(str(dify_trace_id))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
span.set_attribute("test", "true")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
"""Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
|
||||
try:
|
||||
project_name = self.arize_phoenix_config.project
|
||||
endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
|
||||
|
||||
# Arize
|
||||
if isinstance(self.arize_phoenix_config, ArizeConfig):
|
||||
return f"https://app.arize.com/?redirect_project_name={project_name}"
|
||||
|
||||
# Phoenix
|
||||
return f"{endpoint}/projects/?redirect_project_name={project_name}"
|
||||
|
||||
except Exception as e:
|
||||
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
|
||||
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
|
||||
attributes: dict[str, str] = {}
|
||||
|
||||
def set_attribute(path: str, value: object) -> None:
|
||||
"""Store an attribute safely as a string."""
|
||||
if value is None:
|
||||
return
|
||||
try:
|
||||
if isinstance(value, (dict, list)):
|
||||
value = safe_json_dumps(value)
|
||||
attributes[path] = str(value)
|
||||
except Exception:
|
||||
attributes[path] = str(value)
|
||||
|
||||
def set_message_attribute(message_index: int, key: str, value: object) -> None:
|
||||
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
|
||||
set_attribute(path, value)
|
||||
|
||||
def set_tool_call_attributes(
|
||||
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
|
||||
) -> None:
|
||||
"""Extract and assign tool call details safely."""
|
||||
if not tool_call:
|
||||
return
|
||||
|
||||
def safe_get(obj, key, default=None):
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(key, default)
|
||||
return getattr(obj, key, default)
|
||||
|
||||
function_obj = safe_get(tool_call, "function", {})
|
||||
function_name = safe_get(function_obj, "name", "")
|
||||
function_args = safe_get(function_obj, "arguments", {})
|
||||
call_id = safe_get(tool_call, "id", "")
|
||||
|
||||
base_path = (
|
||||
f"{SpanAttributes.LLM_INPUT_MESSAGES}."
|
||||
f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
|
||||
)
|
||||
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
|
||||
set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
|
||||
|
||||
# Handle list of messages
|
||||
if isinstance(prompts, list):
|
||||
for message_index, message in enumerate(prompts):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role", "user")
|
||||
content = message.get("text") or message.get("content") or ""
|
||||
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
|
||||
set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_index, tool_call in enumerate(tool_calls):
|
||||
set_tool_call_attributes(message_index, tool_index, tool_call)
|
||||
|
||||
# Handle single dict or plain string prompt
|
||||
elif isinstance(prompts, (dict, str)):
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
|
||||
set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
|
||||
|
||||
return attributes
|
||||
@ -1,8 +1,8 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import validate_project_name, validate_url
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
|
||||
return validate_project_name(v, default_name)
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langfuse tracing config.
|
||||
"""
|
||||
|
||||
public_key: str
|
||||
secret_key: str
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://api.langfuse.com")
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Langsmith tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
project: str
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Opik tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
workspace: str | None = None
|
||||
url: str = "https://www.comet.com/opik/api/"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "Default Project")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Weave tracing config.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# aliyun uses two URL formats, which may include a URL path
|
||||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TencentConfig(BaseTracingConfig):
|
||||
"""
|
||||
Tencent APM tracing config
|
||||
"""
|
||||
|
||||
token: str
|
||||
endpoint: str
|
||||
service_name: str
|
||||
|
||||
@field_validator("token")
|
||||
@classmethod
|
||||
def token_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Token cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
|
||||
|
||||
@field_validator("service_name")
|
||||
@classmethod
|
||||
def service_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
|
||||
class MLflowConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for MLflow tracing config.
|
||||
"""
|
||||
|
||||
tracking_uri: str = "http://localhost:5000"
|
||||
experiment_id: str = "0" # Default experiment id in MLflow is 0
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator("tracking_uri")
|
||||
@classmethod
|
||||
def tracking_uri_validator(cls, v, info: ValidationInfo):
|
||||
if isinstance(v, str) and v.startswith("databricks"):
|
||||
raise ValueError(
|
||||
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
|
||||
)
|
||||
return validate_url_with_path(v, "http://localhost:5000")
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
class DatabricksConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Databricks (Databricks-managed MLflow) tracing config.
|
||||
"""
|
||||
|
||||
experiment_id: str
|
||||
host: str
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
personal_access_token: str | None = None
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@ -1,283 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
def validate_input_output(v, field_name):
|
||||
"""
|
||||
Validate input output
|
||||
:param v:
|
||||
:param field_name:
|
||||
:return:
|
||||
"""
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": v,
|
||||
}
|
||||
]
|
||||
elif isinstance(v, list):
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
v = replace_text_with_content(data=v)
|
||||
return v
|
||||
else:
|
||||
return [
|
||||
{
|
||||
"role": "assistant" if field_name == "output" else "user",
|
||||
"content": str(v),
|
||||
}
|
||||
]
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class LevelEnum(StrEnum):
|
||||
DEBUG = "DEBUG"
|
||||
WARNING = "WARNING"
|
||||
ERROR = "ERROR"
|
||||
DEFAULT = "DEFAULT"
|
||||
|
||||
|
||||
class LangfuseTrace(BaseModel):
|
||||
"""
|
||||
Langfuse trace model
|
||||
"""
|
||||
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
|
||||
"or when creating a distributed trace. Traces are upserted on id.",
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the trace. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
input: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
default=None, description="The input of the trace. Can be any JSON object."
|
||||
)
|
||||
output: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
default=None, description="The output of the trace. Can be any JSON object."
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
user_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
|
||||
"Useful in debugging.",
|
||||
)
|
||||
release: str | None = Field(
|
||||
default=None,
|
||||
description="The release identifier of the current deployment. Used to understand how changes of different "
|
||||
"deployments affect metrics. Useful in debugging.",
|
||||
)
|
||||
tags: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
|
||||
"API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
|
||||
)
|
||||
public: bool | None = Field(
|
||||
default=None,
|
||||
description="You can make a trace public to share it via a public link. This allows others to view the trace "
|
||||
"without needing to log in or be members of your Langfuse project.",
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
|
||||
|
||||
class LangfuseSpan(BaseModel):
|
||||
"""
|
||||
Langfuse span model
|
||||
"""
|
||||
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
|
||||
)
|
||||
session_id: str | None = Field(
|
||||
default=None,
|
||||
description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
|
||||
)
|
||||
trace_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace the span belongs to. Used to link spans to traces.",
|
||||
)
|
||||
user_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
start_time: datetime | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the span started, defaults to the current time.",
|
||||
)
|
||||
end_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the span ended. Automatically set by span.end().",
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the span. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
level: LevelEnum | None = Field(
|
||||
default=None,
|
||||
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
|
||||
"traces with elevated error levels and for highlighting in the UI.",
|
||||
)
|
||||
status_message: str | None = Field(
|
||||
default=None,
|
||||
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
||||
"message of an error event.",
|
||||
)
|
||||
input: Union[str, Mapping[str, Any], list, None] | None = Field(
|
||||
default=None, description="The input of the span. Can be any JSON object."
|
||||
)
|
||||
output: Union[str, Mapping[str, Any], list, None] | None = Field(
|
||||
default=None, description="The output of the span. Can be any JSON object."
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the span type. Used to understand how changes to the span type affect metrics. "
|
||||
"Useful in debugging.",
|
||||
)
|
||||
parent_observation_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the observation the span belongs to. Used to link spans to observations.",
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
|
||||
|
||||
class UnitEnum(StrEnum):
|
||||
CHARACTERS = "CHARACTERS"
|
||||
TOKENS = "TOKENS"
|
||||
SECONDS = "SECONDS"
|
||||
MILLISECONDS = "MILLISECONDS"
|
||||
IMAGES = "IMAGES"
|
||||
|
||||
|
||||
class GenerationUsage(BaseModel):
|
||||
promptTokens: int | None = None
|
||||
completionTokens: int | None = None
|
||||
total: int | None = None
|
||||
input: int | None = None
|
||||
output: int | None = None
|
||||
unit: UnitEnum | None = None
|
||||
inputCost: float | None = None
|
||||
outputCost: float | None = None
|
||||
totalCost: float | None = None
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
|
||||
|
||||
class LangfuseGeneration(BaseModel):
|
||||
id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the generation can be set, defaults to random id.",
|
||||
)
|
||||
trace_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the trace the generation belongs to. Used to link generations to traces.",
|
||||
)
|
||||
parent_observation_id: str | None = Field(
|
||||
default=None,
|
||||
description="The id of the observation the generation belongs to. Used to link generations to observations.",
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
start_time: datetime | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the generation started, defaults to the current time.",
|
||||
)
|
||||
completion_start_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
|
||||
"down into time until completion started and completion duration.",
|
||||
)
|
||||
end_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the generation ended. Automatically set by generation.end().",
|
||||
)
|
||||
model: str | None = Field(default=None, description="The name of the model used for the generation.")
|
||||
model_parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="The parameters of the model used for the generation; can be any key-value pairs.",
|
||||
)
|
||||
input: Any | None = Field(
|
||||
default=None,
|
||||
description="The prompt used for the generation. Can be any string or JSON object.",
|
||||
)
|
||||
output: Any | None = Field(
|
||||
default=None,
|
||||
description="The completion generated by the model. Can be any string or JSON object.",
|
||||
)
|
||||
usage: GenerationUsage | None = Field(
|
||||
default=None,
|
||||
description="The usage object supports the OpenAi structure with tokens and a more generic version with "
|
||||
"detailed costs and units.",
|
||||
)
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
|
||||
"updated via the API.",
|
||||
)
|
||||
level: LevelEnum | None = Field(
|
||||
default=None,
|
||||
description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
|
||||
"of traces with elevated error levels and for highlighting in the UI.",
|
||||
)
|
||||
status_message: str | None = Field(
|
||||
default=None,
|
||||
description="The status message of the generation. Additional field for context of the event. E.g. the error "
|
||||
"message of an error event.",
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="The version of the generation type. Used to understand how changes to the span type affect "
|
||||
"metrics. Useful in debugging.",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -1,569 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
CreateGenerationBody,
|
||||
CreateSpanBody,
|
||||
IngestionEvent_GenerationCreate,
|
||||
IngestionEvent_SpanCreate,
|
||||
IngestionEvent_TraceCreate,
|
||||
TraceBody,
|
||||
)
|
||||
from langfuse.api.commons.types.usage import Usage
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
GenerationUsage,
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
LangfuseTrace,
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LangFuseDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
langfuse_config: LangfuseConfig,
|
||||
):
|
||||
super().__init__(langfuse_config)
|
||||
self.langfuse_client = Langfuse(
|
||||
public_key=langfuse_config.public_key,
|
||||
secret_key=langfuse_config.secret_key,
|
||||
host=langfuse_config.host,
|
||||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
@staticmethod
|
||||
def _get_completion_start_time(
|
||||
start_time: datetime | None, time_to_first_token: float | int | None
|
||||
) -> datetime | None:
|
||||
"""Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
|
||||
if start_time is None or time_to_first_token is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
ttft_seconds = float(time_to_first_token)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if ttft_seconds < 0:
|
||||
return None
|
||||
|
||||
return start_time + timedelta(seconds=ttft_seconds)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["message", "workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=trace_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=metadata,
|
||||
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
self.add_span(langfuse_span_data=workflow_span_data)
|
||||
else:
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
metadata=metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
tags=["workflow"],
|
||||
version=trace_info.workflow_run_version,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"node_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
process_data = node_execution.process_data or {}
|
||||
model_provider = process_data.get("model_provider", None)
|
||||
model_name = process_data.get("model_name", None)
|
||||
if model_provider is not None and model_name is not None:
|
||||
metadata.update(
|
||||
{
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
}
|
||||
)
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
completion_start_time = None
|
||||
try:
|
||||
usage_data = process_data.get("usage")
|
||||
if not isinstance(usage_data, dict):
|
||||
usage_data = outputs.get("usage")
|
||||
if not isinstance(usage_data, dict):
|
||||
usage_data = {}
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
completion_start_time = self._get_completion_start_time(
|
||||
created_at, usage_data.get("time_to_first_token")
|
||||
)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
# add generation
|
||||
generation_usage = GenerationUsage(
|
||||
input=prompt_tokens,
|
||||
output=completion_tokens,
|
||||
total=total_token,
|
||||
unit=UnitEnum.TOKENS,
|
||||
)
|
||||
|
||||
node_generation_data = LangfuseGeneration(
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
trace_id=trace_id,
|
||||
model=process_data.get("model_name"),
|
||||
start_time=created_at,
|
||||
completion_start_time=completion_start_time,
|
||||
end_time=finished_at,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
self.add_generation(langfuse_generation_data=node_generation_data)
|
||||
|
||||
# add normal span
|
||||
else:
|
||||
span_data = LangfuseSpan(
|
||||
id=node_execution_id,
|
||||
name=node_name,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
trace_id=trace_id,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
|
||||
# get message file data
|
||||
file_list = trace_info.file_list
|
||||
metadata = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
input={
|
||||
"message": trace_info.inputs,
|
||||
"files": file_list,
|
||||
"message_tokens": trace_info.message_tokens,
|
||||
"answer_tokens": trace_info.answer_tokens,
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"error": trace_info.error,
|
||||
"provider_response_latency": message_data.provider_response_latency,
|
||||
"created_at": trace_info.start_time,
|
||||
},
|
||||
output=trace_info.outputs,
|
||||
metadata=metadata,
|
||||
session_id=message_data.conversation_id,
|
||||
tags=["message", str(trace_info.conversation_mode)],
|
||||
version=None,
|
||||
release=None,
|
||||
public=None,
|
||||
)
|
||||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
|
||||
# add generation
|
||||
generation_usage = GenerationUsage(
|
||||
input=trace_info.message_tokens,
|
||||
output=trace_info.answer_tokens,
|
||||
total=trace_info.total_tokens,
|
||||
unit=UnitEnum.TOKENS,
|
||||
totalCost=message_data.total_price,
|
||||
)
|
||||
completion_start_time = self._get_completion_start_time(
|
||||
trace_info.start_time,
|
||||
trace_info.gen_ai_server_time_to_first_token,
|
||||
)
|
||||
|
||||
langfuse_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
trace_id=trace_id,
|
||||
start_time=trace_info.start_time,
|
||||
completion_start_time=completion_start_time,
|
||||
end_time=trace_info.end_time,
|
||||
model=message_data.model_id,
|
||||
input=trace_info.inputs,
|
||||
output=message_data.answer,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
self.add_generation(langfuse_generation_data)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
span_data = LangfuseSpan(
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.created_at,
|
||||
metadata=trace_info.metadata,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
generation_usage = GenerationUsage(
|
||||
total=len(str(trace_info.suggested_question)),
|
||||
input=len(trace_info.inputs) if trace_info.inputs else 0,
|
||||
output=len(trace_info.suggested_question),
|
||||
unit=UnitEnum.CHARACTERS,
|
||||
)
|
||||
|
||||
generation_data = LangfuseGeneration(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=str(trace_info.suggested_question),
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
self.add_generation(langfuse_generation_data=generation_data)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={"documents": trace_info.documents},
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
metadata=trace_info.metadata,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=dataset_retrieval_span_data)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
tool_span_data = LangfuseSpan(
|
||||
name=trace_info.tool_name,
|
||||
input=trace_info.tool_inputs,
|
||||
output=trace_info.tool_outputs,
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
|
||||
status_message=trace_info.error,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=tool_span_data)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_generation_trace_data = LangfuseTrace(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
user_id=trace_info.tenant_id,
|
||||
metadata=trace_info.metadata,
|
||||
session_id=trace_info.conversation_id,
|
||||
)
|
||||
|
||||
self.add_trace(langfuse_trace_data=name_generation_trace_data)
|
||||
|
||||
name_generation_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
trace_id=trace_info.conversation_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
)
|
||||
self.add_span(langfuse_span_data=name_generation_span_data)
|
||||
|
||||
def _make_event_id(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _now_iso(self) -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
|
||||
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
try:
|
||||
body = TraceBody(
|
||||
id=data.get("id"),
|
||||
name=data.get("name"),
|
||||
user_id=data.get("user_id"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
session_id=data.get("session_id"),
|
||||
version=data.get("version"),
|
||||
release=data.get("release"),
|
||||
tags=data.get("tags"),
|
||||
public=data.get("public"),
|
||||
)
|
||||
event = IngestionEvent_TraceCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Trace created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
|
||||
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
try:
|
||||
body = CreateSpanBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
)
|
||||
event = IngestionEvent_SpanCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
|
||||
|
||||
def update_span(self, span, langfuse_span_data: LangfuseSpan | None = None):
|
||||
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
|
||||
span.end(**format_span_data)
|
||||
|
||||
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
try:
|
||||
usage_data = data.pop("usage", None)
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
input=usage_data.get("input", 0) or 0,
|
||||
output=usage_data.get("output", 0) or 0,
|
||||
total=usage_data.get("total", 0) or 0,
|
||||
unit=usage_data.get("unit"),
|
||||
input_cost=usage_data.get("inputCost"),
|
||||
output_cost=usage_data.get("outputCost"),
|
||||
total_cost=usage_data.get("totalCost"),
|
||||
)
|
||||
|
||||
body = CreateGenerationBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
model=data.get("model"),
|
||||
model_parameters=data.get("model_parameters"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
usage=usage,
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
completion_start_time=data.get("completion_start_time"),
|
||||
)
|
||||
event = IngestionEvent_GenerationCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Generation created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
|
||||
|
||||
def update_generation(self, generation, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
format_generation_data = (
|
||||
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
)
|
||||
|
||||
generation.end(**format_generation_data)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
return self.langfuse_client.auth_check()
|
||||
except Exception as e:
|
||||
logger.debug("LangFuse API check failed: %s", str(e))
|
||||
raise ValueError(f"LangFuse API check failed: {str(e)}")
|
||||
|
||||
def get_project_key(self):
|
||||
try:
|
||||
projects = self.langfuse_client.api.projects.get()
|
||||
return projects.data[0].id
|
||||
except Exception as e:
|
||||
logger.debug("LangFuse get project key failed: %s", str(e))
|
||||
raise ValueError(f"LangFuse get project key failed: {str(e)}")
|
||||
@ -1,142 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
class LangSmithRunType(StrEnum):
|
||||
tool = "tool"
|
||||
chain = "chain"
|
||||
llm = "llm"
|
||||
retriever = "retriever"
|
||||
embedding = "embedding"
|
||||
prompt = "prompt"
|
||||
parser = "parser"
|
||||
|
||||
|
||||
class LangSmithTokenUsage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class LangSmithMultiModel(BaseModel):
|
||||
file_list: list[str] | None = Field(None, description="List of files")
|
||||
|
||||
|
||||
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
name: str | None = Field(..., description="Name of the run")
|
||||
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the run")
|
||||
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the run")
|
||||
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
||||
start_time: datetime | str | None = Field(None, description="Start time of the run")
|
||||
end_time: datetime | str | None = Field(None, description="End time of the run")
|
||||
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
|
||||
error: str | None = Field(None, description="Error message of the run")
|
||||
serialized: dict[str, Any] | None = Field(None, description="Serialized data of the run")
|
||||
parent_run_id: str | None = Field(None, description="Parent run ID")
|
||||
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
|
||||
tags: list[str] | None = Field(None, description="Tags associated with the run")
|
||||
trace_id: str | None = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: str | None = Field(None, description="Dotted order of the run")
|
||||
id: str | None = Field(None, description="ID of the run")
|
||||
session_id: str | None = Field(None, description="Session ID associated with the run")
|
||||
session_name: str | None = Field(None, description="Session name associated with the run")
|
||||
reference_example_id: str | None = Field(None, description="Reference example ID associated with the run")
|
||||
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
|
||||
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
usage_metadata = {
|
||||
"input_tokens": values.get("input_tokens", 0),
|
||||
"output_tokens": values.get("output_tokens", 0),
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("start_time", "end_time")
|
||||
def format_time(cls, v, info: ValidationInfo):
|
||||
if not isinstance(v, datetime):
|
||||
raise ValueError(f"{info.field_name} must be a datetime object")
|
||||
else:
|
||||
return v.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
|
||||
class LangSmithRunUpdateModel(BaseModel):
|
||||
run_id: str = Field(..., description="ID of the run")
|
||||
trace_id: str | None = Field(None, description="Trace ID associated with the run")
|
||||
dotted_order: str | None = Field(None, description="Dotted order of the run")
|
||||
parent_run_id: str | None = Field(None, description="Parent run ID")
|
||||
end_time: datetime | str | None = Field(None, description="End time of the run")
|
||||
error: str | None = Field(None, description="Error message of the run")
|
||||
inputs: dict[str, Any] | None = Field(None, description="Inputs of the run")
|
||||
outputs: dict[str, Any] | None = Field(None, description="Outputs of the run")
|
||||
events: list[dict[str, Any]] | None = Field(None, description="Events associated with the run")
|
||||
tags: list[str] | None = Field(None, description="Tags associated with the run")
|
||||
extra: dict[str, Any] | None = Field(None, description="Extra information of the run")
|
||||
input_attachments: dict[str, Any] | None = Field(None, description="Input attachments of the run")
|
||||
output_attachments: dict[str, Any] | None = Field(None, description="Output attachments of the run")
|
||||
@ -1,523 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
LangSmithRunModel,
|
||||
LangSmithRunType,
|
||||
LangSmithRunUpdateModel,
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LangSmithDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
langsmith_config: LangSmithConfig,
|
||||
):
|
||||
super().__init__(langsmith_config)
|
||||
self.langsmith_key = langsmith_config.api_key
|
||||
self.project_name = langsmith_config.project
|
||||
self.project_id = None
|
||||
self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
|
||||
)
|
||||
workflow_dotted_order = generate_dotted_order(
|
||||
trace_info.workflow_run_id,
|
||||
trace_info.workflow_data.created_at,
|
||||
message_dotted_order,
|
||||
)
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
run_type=LangSmithRunType.chain,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
tags=["message", "workflow"],
|
||||
error=trace_info.error,
|
||||
trace_id=trace_id,
|
||||
dotted_order=message_dotted_order,
|
||||
file_list=[],
|
||||
serialized=None,
|
||||
parent_run_id=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
self.add_run(message_run)
|
||||
|
||||
langsmith_run = LangSmithRunModel(
|
||||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
run_type=LangSmithRunType.tool,
|
||||
start_time=trace_info.workflow_data.created_at,
|
||||
end_time=trace_info.workflow_data.finished_at,
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
error=trace_info.error,
|
||||
tags=["workflow"],
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
trace_id=trace_id,
|
||||
dotted_order=workflow_dotted_order,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = LangSmithRunType.llm
|
||||
metadata.update(
|
||||
{
|
||||
"ls_provider": process_data.get("model_provider", ""),
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
run_type = LangSmithRunType.retriever
|
||||
else:
|
||||
run_type = LangSmithRunType.tool
|
||||
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
try:
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
|
||||
langsmith_run = LangSmithRunModel(
|
||||
total_tokens=node_total_tokens,
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
name=node_type,
|
||||
inputs=inputs,
|
||||
run_type=run_type,
|
||||
start_time=created_at,
|
||||
end_time=finished_at,
|
||||
outputs=outputs,
|
||||
file_list=trace_info.file_list,
|
||||
extra={
|
||||
"metadata": metadata,
|
||||
},
|
||||
parent_run_id=trace_info.workflow_run_id,
|
||||
tags=["node_execution"],
|
||||
id=node_execution_id,
|
||||
trace_id=trace_id,
|
||||
dotted_order=node_dotted_order,
|
||||
error="",
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
metadata = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
message_run = LangSmithRunModel(
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
run_type=LangSmithRunType.chain,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
outputs=message_data.answer,
|
||||
extra={"metadata": metadata},
|
||||
tags=["message", str(trace_info.conversation_mode)],
|
||||
error=trace_info.error,
|
||||
file_list=file_list,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
parent_run_id=None,
|
||||
)
|
||||
self.add_run(message_run)
|
||||
|
||||
# create llm run parented to message run
|
||||
llm_run = LangSmithRunModel(
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
name="llm",
|
||||
inputs=trace_info.inputs,
|
||||
run_type=LangSmithRunType.llm,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
outputs=message_data.answer,
|
||||
extra={"metadata": metadata},
|
||||
parent_run_id=message_id,
|
||||
tags=["llm", str(trace_info.conversation_mode)],
|
||||
error=trace_info.error,
|
||||
file_list=file_list,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
self.add_run(llm_run)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
langsmith_run = LangSmithRunModel(
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["moderation"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(langsmith_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
suggested_question_run = LangSmithRunModel(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["suggested_question"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or message_data.created_at,
|
||||
end_time=trace_info.end_time or message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_run = LangSmithRunModel(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
run_type=LangSmithRunType.retriever,
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["dataset_retrieval"],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.add_run(dataset_retrieval_run)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
tool_run = LangSmithRunModel(
|
||||
name=trace_info.tool_name,
|
||||
inputs=trace_info.tool_inputs,
|
||||
outputs=trace_info.tool_outputs,
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={
|
||||
"metadata": trace_info.metadata,
|
||||
},
|
||||
tags=["tool", trace_info.tool_name],
|
||||
parent_run_id=trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
file_list=[cast(str, trace_info.file_url)],
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error=trace_info.error or "",
|
||||
)
|
||||
|
||||
self.add_run(tool_run)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_run = LangSmithRunModel(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
run_type=LangSmithRunType.tool,
|
||||
extra={"metadata": trace_info.metadata},
|
||||
tags=["generate_name"],
|
||||
start_time=trace_info.start_time or datetime.now(),
|
||||
end_time=trace_info.end_time or datetime.now(),
|
||||
id=str(uuid.uuid4()),
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
session_name=None,
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
parent_run_id=None,
|
||||
)
|
||||
|
||||
self.add_run(name_run)
|
||||
|
||||
def add_run(self, run_data: LangSmithRunModel):
|
||||
data = run_data.model_dump()
|
||||
if self.project_id:
|
||||
data["session_id"] = self.project_id
|
||||
elif self.project_name:
|
||||
data["session_name"] = self.project_name
|
||||
|
||||
data = filter_none_values(data)
|
||||
try:
|
||||
self.langsmith_client.create_run(**data)
|
||||
logger.debug("LangSmith Run created successfully.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangSmith Failed to create run: {str(e)}")
|
||||
|
||||
def update_run(self, update_run_data: LangSmithRunUpdateModel):
|
||||
data = update_run_data.model_dump()
|
||||
data = filter_none_values(data)
|
||||
try:
|
||||
self.langsmith_client.update_run(**data)
|
||||
logger.debug("LangSmith Run updated successfully.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangSmith Failed to update run: {str(e)}")
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
self.langsmith_client.create_project(project_name=random_project_name)
|
||||
self.langsmith_client.delete_project(project_name=random_project_name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("LangSmith API check failed: %s", str(e))
|
||||
raise ValueError(f"LangSmith API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
run_data = RunBase(
|
||||
id=uuid.uuid4(),
|
||||
name="tool",
|
||||
inputs={"input": "test"},
|
||||
outputs={"output": "test"},
|
||||
run_type=LangSmithRunType.tool,
|
||||
start_time=datetime.now(),
|
||||
)
|
||||
|
||||
project_url = self.langsmith_client.get_run_url(
|
||||
run=run_data, project_id=self.project_id, project_name=self.project_name
|
||||
)
|
||||
return project_url.split("/r/")[0]
|
||||
except Exception as e:
|
||||
logger.debug("LangSmith get run url failed: %s", str(e))
|
||||
raise ValueError(f"LangSmith get run url failed: {str(e)}")
|
||||
@ -1,536 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
|
||||
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
|
||||
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
|
||||
"""Convert datetime to nanosecond timestamp for MLflow API"""
|
||||
if dt is None:
|
||||
return None
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
class MLflowDataTrace(BaseTraceInstance):
|
||||
def __init__(self, config: MLflowConfig | DatabricksConfig):
|
||||
super().__init__(config)
|
||||
if isinstance(config, DatabricksConfig):
|
||||
self._setup_databricks(config)
|
||||
else:
|
||||
self._setup_mlflow(config)
|
||||
|
||||
# Enable async logging to minimize performance overhead
|
||||
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
|
||||
|
||||
def _setup_databricks(self, config: DatabricksConfig):
|
||||
"""Setup connection to Databricks-managed MLflow instances"""
|
||||
os.environ["DATABRICKS_HOST"] = config.host
|
||||
|
||||
if config.client_id and config.client_secret:
|
||||
# OAuth: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m?language=Environment
|
||||
os.environ["DATABRICKS_CLIENT_ID"] = config.client_id
|
||||
os.environ["DATABRICKS_CLIENT_SECRET"] = config.client_secret
|
||||
elif config.personal_access_token:
|
||||
# PAT: https://docs.databricks.com/aws/en/dev-tools/auth/pat
|
||||
os.environ["DATABRICKS_TOKEN"] = config.personal_access_token
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either Databricks token (PAT) or client id and secret (OAuth) must be provided"
|
||||
"See https://docs.databricks.com/aws/en/dev-tools/auth/#what-authorization-option-should-i-choose "
|
||||
"for more information about the authorization options."
|
||||
)
|
||||
mlflow.set_tracking_uri("databricks")
|
||||
mlflow.set_experiment(experiment_id=config.experiment_id)
|
||||
|
||||
# Remove trailing slash from host
|
||||
config.host = config.host.rstrip("/")
|
||||
self._project_url = f"{config.host}/ml/experiments/{config.experiment_id}/traces"
|
||||
|
||||
def _setup_mlflow(self, config: MLflowConfig):
|
||||
"""Setup connection to MLflow instances"""
|
||||
mlflow.set_tracking_uri(config.tracking_uri)
|
||||
mlflow.set_experiment(experiment_id=config.experiment_id)
|
||||
|
||||
# Simple auth if provided
|
||||
if config.username and config.password:
|
||||
os.environ["MLFLOW_TRACKING_USERNAME"] = config.username
|
||||
os.environ["MLFLOW_TRACKING_PASSWORD"] = config.password
|
||||
|
||||
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
"""Simple dispatch to trace methods"""
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
except Exception:
|
||||
logger.exception("[MLflow] Trace error")
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
"""Create workflow span as root, with node spans as children"""
|
||||
# fields with sys.xyz is added by Dify, they are duplicate to trace_info.metadata
|
||||
raw_inputs = trace_info.workflow_run_inputs or {}
|
||||
workflow_inputs = {k: v for k, v in raw_inputs.items() if not k.startswith("sys.")}
|
||||
|
||||
# Special inputs propagated by system
|
||||
if trace_info.query:
|
||||
workflow_inputs["query"] = trace_info.query
|
||||
|
||||
workflow_span = start_span_no_context(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=workflow_inputs,
|
||||
attributes=trace_info.metadata,
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
# Set reserved fields in trace-level metadata
|
||||
trace_metadata = {}
|
||||
if user_id := trace_info.metadata.get("user_id"):
|
||||
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
|
||||
if session_id := trace_info.conversation_id:
|
||||
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
|
||||
self._set_trace_metadata(workflow_span, trace_metadata)
|
||||
|
||||
try:
|
||||
# Create child spans for workflow nodes
|
||||
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
inputs = None
|
||||
attributes = {
|
||||
"node_id": node.id,
|
||||
"node_type": node.node_type,
|
||||
"status": node.status,
|
||||
"tenant_id": node.tenant_id,
|
||||
"app_id": node.app_id,
|
||||
"app_name": node.title,
|
||||
}
|
||||
|
||||
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
|
||||
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
|
||||
attributes.update(llm_attributes)
|
||||
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
inputs = node.process_data # contains request URL
|
||||
|
||||
if not inputs:
|
||||
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
|
||||
|
||||
node_span = start_span_no_context(
|
||||
name=node.title,
|
||||
span_type=self._get_node_span_type(node.node_type),
|
||||
parent_span=workflow_span,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
start_time_ns=datetime_to_nanoseconds(node.created_at),
|
||||
)
|
||||
|
||||
# Handle node errors
|
||||
if node.status != "succeeded":
|
||||
node_span.set_status(SpanStatusCode.ERROR)
|
||||
node_span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="exception",
|
||||
attributes={
|
||||
"exception.message": f"Node failed with status: {node.status}",
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": f"Node failed with status: {node.status}",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# End node span
|
||||
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
|
||||
outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {}
|
||||
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
outputs = self._parse_knowledge_retrieval_outputs(outputs)
|
||||
elif node.node_type == BuiltinNodeTypes.LLM:
|
||||
outputs = outputs.get("text", outputs)
|
||||
node_span.end(
|
||||
outputs=outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(finished_at),
|
||||
)
|
||||
|
||||
# Handle workflow-level errors
|
||||
if trace_info.error:
|
||||
workflow_span.set_status(SpanStatusCode.ERROR)
|
||||
workflow_span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
workflow_span.end(
|
||||
outputs=trace_info.workflow_run_outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
|
||||
"""Parse LLM inputs and attributes from LLM workflow node"""
|
||||
if node.process_data is None:
|
||||
return {}, {}
|
||||
|
||||
try:
|
||||
data = JSON_DICT_ADAPTER.validate_json(node.process_data)
|
||||
except (ValueError, TypeError):
|
||||
return {}, {}
|
||||
|
||||
inputs = self._parse_prompts(data.get("prompts"))
|
||||
attributes = {
|
||||
"model_name": data.get("model_name"),
|
||||
"model_provider": data.get("model_provider"),
|
||||
"finish_reason": data.get("finish_reason"),
|
||||
}
|
||||
|
||||
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
|
||||
attributes[SpanAttributeKey.MESSAGE_FORMAT] = "dify"
|
||||
|
||||
if usage := data.get("usage"):
|
||||
# Set reserved token usage attributes
|
||||
attributes[SpanAttributeKey.CHAT_USAGE] = {
|
||||
TokenUsageKey.INPUT_TOKENS: usage.get("prompt_tokens", 0),
|
||||
TokenUsageKey.OUTPUT_TOKENS: usage.get("completion_tokens", 0),
|
||||
TokenUsageKey.TOTAL_TOKENS: usage.get("total_tokens", 0),
|
||||
}
|
||||
# Store raw usage data as well as it includes more data like price
|
||||
attributes["usage"] = usage
|
||||
|
||||
return inputs, attributes
|
||||
|
||||
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
|
||||
"""Parse KR outputs and attributes from KR workflow node"""
|
||||
retrieved = outputs.get("result", [])
|
||||
|
||||
if not retrieved or not isinstance(retrieved, list):
|
||||
return outputs
|
||||
|
||||
documents = []
|
||||
for item in retrieved:
|
||||
documents.append(Document(page_content=item.get("content", ""), metadata=item.get("metadata", {})))
|
||||
return documents
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
"""Create span for CHATBOT message processing"""
|
||||
if not trace_info.message_data:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
if message_file_data := trace_info.message_file_data:
|
||||
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
file_list.append(f"{base_url}/{message_file_data.url}")
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
span_type=SpanType.LLM,
|
||||
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"model_provider": trace_info.message_data.model_provider,
|
||||
"model_id": trace_info.message_data.model_id,
|
||||
"conversation_mode": trace_info.conversation_mode,
|
||||
"file_list": file_list, # type: ignore[dict-item]
|
||||
"total_price": trace_info.message_data.total_price,
|
||||
**trace_info.metadata,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
|
||||
span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "dify")
|
||||
|
||||
# Set token usage
|
||||
span.set_attribute(
|
||||
SpanAttributeKey.CHAT_USAGE,
|
||||
{
|
||||
TokenUsageKey.INPUT_TOKENS: trace_info.message_tokens or 0,
|
||||
TokenUsageKey.OUTPUT_TOKENS: trace_info.answer_tokens or 0,
|
||||
TokenUsageKey.TOTAL_TOKENS: trace_info.total_tokens or 0,
|
||||
},
|
||||
)
|
||||
|
||||
# Set reserved fields in trace-level metadata
|
||||
trace_metadata = {}
|
||||
if user_id := self._get_message_user_id(trace_info.metadata):
|
||||
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
|
||||
if session_id := trace_info.metadata.get("conversation_id"):
|
||||
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
|
||||
self._set_trace_metadata(span, trace_metadata)
|
||||
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs=trace_info.message_data.answer,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
|
||||
if (end_user_id := metadata.get("from_end_user_id")) and (
|
||||
end_user_data := db.session.get(EndUser, end_user_id)
|
||||
):
|
||||
return end_user_data.session_id
|
||||
|
||||
return metadata.get("from_account_id") # type: ignore[return-value]
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span = start_span_no_context(
|
||||
name=trace_info.tool_name,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
|
||||
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
# Handle tool errors
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs=trace_info.tool_outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs or {},
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
},
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
span_type=SpanType.RETRIEVER,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
span.end(outputs={"documents": trace_info.documents}, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
|
||||
"model_id": trace_info.model_id, # type: ignore[dict-item]
|
||||
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowNodeExecutionModel.created_at)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _get_node_span_type(self, node_type: str) -> str:
|
||||
"""Map Dify node types to MLflow span types"""
|
||||
node_type_mapping = {
|
||||
BuiltinNodeTypes.LLM: SpanType.LLM,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
BuiltinNodeTypes.TOOL: SpanType.TOOL,
|
||||
BuiltinNodeTypes.CODE: SpanType.TOOL,
|
||||
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
|
||||
BuiltinNodeTypes.AGENT: SpanType.AGENT,
|
||||
}
|
||||
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
|
||||
|
||||
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
|
||||
token = None
|
||||
try:
|
||||
# NB: Set span in context such that we can use update_current_trace() API
|
||||
token = set_span_in_context(span)
|
||||
update_current_trace(metadata=metadata)
|
||||
finally:
|
||||
if token:
|
||||
detach_span_from_context(token)
|
||||
|
||||
def _parse_prompts(self, prompts):
|
||||
"""Postprocess prompts format to be standard chat messages"""
|
||||
if isinstance(prompts, str):
|
||||
return prompts
|
||||
elif isinstance(prompts, dict):
|
||||
return self._parse_single_message(prompts)
|
||||
elif isinstance(prompts, list):
|
||||
messages = [self._parse_single_message(item) for item in prompts]
|
||||
messages = self._resolve_tool_call_ids(messages)
|
||||
return messages
|
||||
return prompts # Fallback to original format
|
||||
|
||||
def _parse_single_message(self, item: dict[str, Any]):
|
||||
"""Postprocess single message format to be standard chat message"""
|
||||
role = item.get("role", "user")
|
||||
msg = {"role": role, "content": item.get("text", "")}
|
||||
|
||||
if (
|
||||
(tool_calls := item.get("tool_calls"))
|
||||
# Tool message does not contain tool calls normally
|
||||
and role != "tool"
|
||||
):
|
||||
msg["tool_calls"] = tool_calls
|
||||
|
||||
if files := item.get("files"):
|
||||
msg["files"] = files
|
||||
|
||||
return msg
|
||||
|
||||
def _resolve_tool_call_ids(self, messages: list[dict]):
|
||||
"""
|
||||
The tool call message from Dify does not contain tool call ids, which is not
|
||||
ideal for debugging. This method resolves the tool call ids by matching the
|
||||
tool call name and parameters with the tool instruction messages.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
for msg in messages:
|
||||
if tool_calls := msg.get("tool_calls"):
|
||||
tool_call_ids = [t["id"] for t in tool_calls]
|
||||
if msg["role"] == "tool":
|
||||
# Get the tool call id in the order of the tool call messages
|
||||
# assuming Dify runs tools sequentially
|
||||
if tool_call_ids:
|
||||
msg["tool_call_id"] = tool_call_ids.pop(0)
|
||||
return messages
|
||||
|
||||
def api_check(self):
|
||||
"""Simple connection test"""
|
||||
try:
|
||||
mlflow.search_experiments(max_results=1)
|
||||
return True
|
||||
except Exception as e:
|
||||
raise ValueError(f"MLflow connection failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
return self._project_url
|
||||
@ -1,467 +0,0 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wrap_dict(key_name, data):
|
||||
"""Make sure that the input data is a dict"""
|
||||
if not isinstance(data, dict):
|
||||
return {key_name: data}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def wrap_metadata(metadata, **kwargs):
|
||||
"""Add common metatada to all Traces and Spans"""
|
||||
metadata["created_from"] = "dify"
|
||||
|
||||
metadata.update(kwargs)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def _seed_to_uuid4(seed: str) -> str:
|
||||
"""Derive a deterministic UUID4-formatted string from an arbitrary seed.
|
||||
|
||||
uuid4_to_uuid7 requires a valid UUID v4 string, but some Dify identifiers
|
||||
are not UUIDs (e.g. a workflow_run_id with a "-root" suffix appended to
|
||||
distinguish the root span from the trace). This helper hashes the seed
|
||||
with MD5 and patches the version/variant bits so the result satisfies the
|
||||
UUID v4 contract.
|
||||
"""
|
||||
raw = hashlib.md5(seed.encode()).digest()
|
||||
ba = bytearray(raw)
|
||||
ba[6] = (ba[6] & 0x0F) | 0x40 # version 4
|
||||
ba[8] = (ba[8] & 0x3F) | 0x80 # variant 1
|
||||
return str(uuid.UUID(bytes=bytes(ba)))
|
||||
|
||||
|
||||
def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None):
|
||||
"""Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
|
||||
messages and objects. The type-hints of BaseTraceInfo indicates that
|
||||
objects start_time and message_id could be null which means we cannot map
|
||||
it to a UUIDv7. Given that we have no way to identify that object
|
||||
uniquely, generate a new random one UUIDv7 in that case.
|
||||
"""
|
||||
|
||||
if user_datetime is None:
|
||||
user_datetime = datetime.now()
|
||||
|
||||
if user_uuid is None:
|
||||
user_uuid = str(uuid.uuid4())
|
||||
|
||||
return uuid4_to_uuid7(user_datetime, user_uuid)
|
||||
|
||||
|
||||
class OpikDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
opik_config: OpikConfig,
|
||||
):
|
||||
super().__init__(opik_config)
|
||||
self.opik_client = Opik(
|
||||
project_name=opik_config.project,
|
||||
workspace=opik_config.workspace,
|
||||
host=opik_config.url,
|
||||
api_key=opik_config.api_key,
|
||||
)
|
||||
self.project = opik_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
workflow_metadata = wrap_metadata(
|
||||
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
|
||||
)
|
||||
|
||||
if trace_info.message_id:
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
trace_name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_tags = ["message", "workflow"]
|
||||
root_span_seed = trace_info.workflow_run_id
|
||||
else:
|
||||
dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
trace_name = TraceTaskName.WORKFLOW_TRACE
|
||||
trace_tags = ["workflow"]
|
||||
root_span_seed = _seed_to_uuid4(trace_info.workflow_run_id + "-root")
|
||||
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": trace_name,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"thread_id": trace_info.conversation_id,
|
||||
"tags": trace_tags,
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_trace(trace_data)
|
||||
|
||||
root_span_id = prepare_opik_uuid(trace_info.start_time, root_span_seed)
|
||||
span_data = {
|
||||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
"tags": ["workflow"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
|
||||
provider = None
|
||||
model = None
|
||||
total_tokens = 0
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = "llm"
|
||||
provider = process_data.get("model_provider", None)
|
||||
model = process_data.get("model_name", "")
|
||||
metadata.update(
|
||||
{
|
||||
"ls_provider": provider,
|
||||
"ls_model_name": model,
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
total_tokens = usage_data.get("total_tokens", 0)
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
else:
|
||||
run_type = "tool"
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
"id": prepare_opik_uuid(created_at, node_execution_id),
|
||||
"parent_span_id": root_span_id,
|
||||
"name": node_name,
|
||||
"type": run_type,
|
||||
"start_time": created_at,
|
||||
"end_time": finished_at,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": wrap_dict("input", inputs),
|
||||
"output": wrap_dict("output", outputs),
|
||||
"tags": ["node_execution"],
|
||||
"project_name": self.project,
|
||||
"usage": {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
},
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
metadata = trace_info.metadata
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
metadata["user_id"] = user_id
|
||||
metadata["file_list"] = file_list
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": trace_info.inputs,
|
||||
"output": message_data.answer,
|
||||
"thread_id": message_data.conversation_id,
|
||||
"tags": ["message", str(trace_info.conversation_mode)],
|
||||
"project_name": self.project,
|
||||
}
|
||||
trace = self.add_trace(trace_data)
|
||||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": "llm",
|
||||
"type": "llm",
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
"input": {"input": trace_info.inputs},
|
||||
"output": {"output": message_data.answer},
|
||||
"tags": ["llm", str(trace_info.conversation_mode)],
|
||||
"usage": {
|
||||
"completion_tokens": trace_info.answer_tokens,
|
||||
"prompt_tokens": trace_info.message_tokens,
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
},
|
||||
"project_name": self.project,
|
||||
}
|
||||
self.add_span(span_data)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.MODERATION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": {
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
"tags": ["moderation"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": wrap_dict("output", trace_info.suggested_question),
|
||||
"tags": ["suggested_question"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": {"documents": trace_info.documents},
|
||||
"tags": ["dataset_retrieval"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": trace_info.tool_name,
|
||||
"type": "tool",
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.tool_inputs),
|
||||
"output": wrap_dict("output", trace_info.tool_outputs),
|
||||
"tags": ["tool", trace_info.tool_name],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": trace_info.inputs,
|
||||
"output": trace_info.outputs,
|
||||
"thread_id": trace_info.conversation_id,
|
||||
"tags": ["generate_name"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
|
||||
trace = self.add_trace(trace_data)
|
||||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": wrap_dict("output", trace_info.outputs),
|
||||
"tags": ["generate_name"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
|
||||
try:
|
||||
trace = self.opik_client.trace(**opik_trace_data)
|
||||
logger.debug("Opik Trace created successfully")
|
||||
return trace
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, opik_span_data: dict[str, Any]):
|
||||
try:
|
||||
self.opik_client.span(**opik_span_data)
|
||||
logger.debug("Opik Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create span: {str(e)}")
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
self.opik_client.auth_check()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info("Opik API check failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Opik API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.opik_client.get_project_url(project_name=self.project)
|
||||
except Exception as e:
|
||||
logger.info("Opik get run url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Opik get run url failed: {str(e)}")
|
||||
@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
try:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.LANGSMITH:
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
case TracingProviderEnum.LANGSMITH:
|
||||
from dify_trace_langsmith.config import LangSmithConfig
|
||||
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.OPIK:
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
||||
case TracingProviderEnum.OPIK:
|
||||
from dify_trace_opik.config import OpikConfig
|
||||
from dify_trace_opik.opik_trace import OpikDataTrace
|
||||
|
||||
return {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.weave_trace.weave_trace import WeaveDataTrace
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from dify_trace_weave.config import WeaveConfig
|
||||
from dify_trace_weave.weave_trace import WeaveDataTrace
|
||||
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import ArizeConfig
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from dify_trace_arize_phoenix.config import ArizeConfig
|
||||
|
||||
return {
|
||||
"config_class": ArizeConfig,
|
||||
"secret_keys": ["api_key", "space_id"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.PHOENIX:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import PhoenixConfig
|
||||
return {
|
||||
"config_class": ArizeConfig,
|
||||
"secret_keys": ["api_key", "space_id"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.PHOENIX:
|
||||
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from dify_trace_arize_phoenix.config import PhoenixConfig
|
||||
|
||||
return {
|
||||
"config_class": PhoenixConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ALIYUN:
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
return {
|
||||
"config_class": PhoenixConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ALIYUN:
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
|
||||
return {
|
||||
"config_class": AliyunConfig,
|
||||
"secret_keys": ["license_key"],
|
||||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.MLFLOW:
|
||||
from core.ops.entities.config_entity import MLflowConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
return {
|
||||
"config_class": AliyunConfig,
|
||||
"secret_keys": ["license_key"],
|
||||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.MLFLOW:
|
||||
from dify_trace_mlflow.config import MLflowConfig
|
||||
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": MLflowConfig,
|
||||
"secret_keys": ["password"],
|
||||
"other_keys": ["tracking_uri", "experiment_id", "username"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.DATABRICKS:
|
||||
from core.ops.entities.config_entity import DatabricksConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
return {
|
||||
"config_class": MLflowConfig,
|
||||
"secret_keys": ["password"],
|
||||
"other_keys": ["tracking_uri", "experiment_id", "username"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.DATABRICKS:
|
||||
from dify_trace_mlflow.config import DatabricksConfig
|
||||
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": DatabricksConfig,
|
||||
"secret_keys": ["personal_access_token", "client_secret"],
|
||||
"other_keys": ["host", "client_id", "experiment_id"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": DatabricksConfig,
|
||||
"secret_keys": ["personal_access_token", "client_secret"],
|
||||
"other_keys": ["host", "client_id", "experiment_id"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from dify_trace_tencent.config import TencentConfig
|
||||
from dify_trace_tencent.tencent_trace import TencentDataTrace
|
||||
|
||||
return {
|
||||
"config_class": TencentConfig,
|
||||
"secret_keys": ["token"],
|
||||
"other_keys": ["endpoint", "service_name"],
|
||||
"trace_instance": TencentDataTrace,
|
||||
}
|
||||
return {
|
||||
"config_class": TencentConfig,
|
||||
"secret_keys": ["token"],
|
||||
"other_keys": ["endpoint", "service_name"],
|
||||
"trace_instance": TencentDataTrace,
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
except ImportError:
|
||||
raise ImportError(f"Provider {provider} is not installed.")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
@ -1,571 +0,0 @@
|
||||
"""
|
||||
Tencent APM Trace Client - handles network operations, metrics, and API communication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
except ImportError:
|
||||
from importlib_metadata import version # type: ignore[import-not-found]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Histogram, Meter
|
||||
from opentelemetry.sdk.metrics.export import MetricReader
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
from .entities.semconv import (
|
||||
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
|
||||
GEN_AI_STREAMING_TIME_TO_GENERATE,
|
||||
GEN_AI_TOKEN_USAGE,
|
||||
GEN_AI_TRACE_DURATION,
|
||||
LLM_OPERATION_DURATION,
|
||||
)
|
||||
from .entities.tencent_trace_entity import SpanData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_opentelemetry_sdk_version() -> str:
|
||||
"""Get OpenTelemetry SDK version dynamically."""
|
||||
try:
|
||||
return version("opentelemetry-sdk")
|
||||
except Exception:
|
||||
logger.debug("Failed to get opentelemetry-sdk version, using default")
|
||||
return "1.27.0" # fallback version
|
||||
|
||||
|
||||
class TencentTraceClient:
|
||||
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
token: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
metrics_export_interval_sec: int = 10,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.service_name = service_name
|
||||
self.metrics_export_interval_sec = metrics_export_interval_sec
|
||||
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
"telemetry.sdk.language": "python",
|
||||
"telemetry.sdk.name": "opentelemetry",
|
||||
"telemetry.sdk.version": _get_opentelemetry_sdk_version(),
|
||||
}
|
||||
)
|
||||
# Prepare gRPC endpoint/metadata
|
||||
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||
|
||||
headers = (("authorization", f"Bearer {token}"),)
|
||||
|
||||
self.exporter = OTLPSpanExporter(
|
||||
endpoint=grpc_endpoint,
|
||||
headers=headers,
|
||||
insecure=insecure,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
self.tracer_provider = TracerProvider(resource=self.resource)
|
||||
self.span_processor = BatchSpanProcessor(
|
||||
span_exporter=self.exporter,
|
||||
max_queue_size=max_queue_size,
|
||||
schedule_delay_millis=schedule_delay_sec * 1000,
|
||||
max_export_batch_size=max_export_batch_size,
|
||||
)
|
||||
self.tracer_provider.add_span_processor(self.span_processor)
|
||||
|
||||
# use dify api version as tracer version
|
||||
self.tracer = self.tracer_provider.get_tracer("dify-sdk", dify_config.project.version)
|
||||
|
||||
# Store span contexts for parent-child relationships
|
||||
self.span_contexts: dict[int, trace_api.SpanContext] = {}
|
||||
|
||||
self.meter: Meter | None = None
|
||||
self.meter_provider: MeterProvider | None = None
|
||||
self.hist_llm_duration: Histogram | None = None
|
||||
self.hist_token_usage: Histogram | None = None
|
||||
self.hist_time_to_first_token: Histogram | None = None
|
||||
self.hist_time_to_generate: Histogram | None = None
|
||||
self.hist_trace_duration: Histogram | None = None
|
||||
self.metric_reader: MetricReader | None = None
|
||||
|
||||
# Metrics exporter and instruments
|
||||
try:
|
||||
from opentelemetry.sdk.metrics import Histogram as SdkHistogram
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
|
||||
|
||||
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
|
||||
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
|
||||
use_http_json = protocol in {"http/json", "http-json"}
|
||||
|
||||
# Tencent APM works best with delta aggregation temporality
|
||||
preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA}
|
||||
|
||||
def _create_metric_exporter(exporter_cls, **kwargs):
|
||||
"""Create metric exporter with preferred_temporality support"""
|
||||
try:
|
||||
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
|
||||
except Exception:
|
||||
return exporter_cls(**kwargs)
|
||||
|
||||
metric_reader = None
|
||||
if use_http_json:
|
||||
exporter_cls = None
|
||||
for mod_path in (
|
||||
"opentelemetry.exporter.otlp.http.json.metric_exporter",
|
||||
"opentelemetry.exporter.otlp.json.metric_exporter",
|
||||
):
|
||||
try:
|
||||
mod = importlib.import_module(mod_path)
|
||||
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
|
||||
if exporter_cls:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if exporter_cls is not None:
|
||||
metric_exporter = _create_metric_exporter(
|
||||
exporter_cls,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
else:
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as HttpMetricExporter,
|
||||
)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
HttpMetricExporter,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
|
||||
elif use_http_protobuf:
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as HttpMetricExporter,
|
||||
)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
HttpMetricExporter,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
else:
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
|
||||
OTLPMetricExporter as GrpcMetricExporter,
|
||||
)
|
||||
|
||||
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
GrpcMetricExporter,
|
||||
endpoint=m_grpc_endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
insecure=m_insecure,
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
|
||||
if metric_reader is not None:
|
||||
# Use instance-level MeterProvider instead of global to support config changes
|
||||
# without worker restart. Each TencentTraceClient manages its own MeterProvider.
|
||||
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
|
||||
self.meter_provider = provider
|
||||
self.meter = provider.get_meter("dify-sdk", dify_config.project.version)
|
||||
|
||||
# LLM operation duration histogram
|
||||
self.hist_llm_duration = self.meter.create_histogram(
|
||||
name=LLM_OPERATION_DURATION,
|
||||
unit="s",
|
||||
description="LLM operation duration (seconds)",
|
||||
)
|
||||
|
||||
# Token usage histogram with exponential buckets
|
||||
self.hist_token_usage = self.meter.create_histogram(
|
||||
name=GEN_AI_TOKEN_USAGE,
|
||||
unit="token",
|
||||
description="Number of tokens used in prompt and completions",
|
||||
)
|
||||
|
||||
# Time to first token histogram
|
||||
self.hist_time_to_first_token = self.meter.create_histogram(
|
||||
name=GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
|
||||
unit="s",
|
||||
description="Time to first token for streaming LLM responses (seconds)",
|
||||
)
|
||||
|
||||
# Time to generate histogram
|
||||
self.hist_time_to_generate = self.meter.create_histogram(
|
||||
name=GEN_AI_STREAMING_TIME_TO_GENERATE,
|
||||
unit="s",
|
||||
description="Total time to generate streaming LLM responses (seconds)",
|
||||
)
|
||||
|
||||
# Trace duration histogram
|
||||
self.hist_trace_duration = self.meter.create_histogram(
|
||||
name=GEN_AI_TRACE_DURATION,
|
||||
unit="s",
|
||||
description="End-to-end GenAI trace duration (seconds)",
|
||||
)
|
||||
|
||||
self.metric_reader = metric_reader
|
||||
else:
|
||||
self.meter = None
|
||||
self.meter_provider = None
|
||||
self.hist_llm_duration = None
|
||||
self.hist_token_usage = None
|
||||
self.hist_time_to_first_token = None
|
||||
self.hist_time_to_generate = None
|
||||
self.hist_trace_duration = None
|
||||
self.metric_reader = None
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
|
||||
self.meter = None
|
||||
self.meter_provider = None
|
||||
self.hist_llm_duration = None
|
||||
self.hist_token_usage = None
|
||||
self.hist_time_to_first_token = None
|
||||
self.hist_time_to_generate = None
|
||||
self.hist_trace_duration = None
|
||||
self.metric_reader = None
|
||||
|
||||
def add_span(self, span_data: SpanData) -> None:
|
||||
"""Create and export span using OpenTelemetry Tracer API"""
|
||||
try:
|
||||
self._create_and_export_span(span_data)
|
||||
logger.debug("[Tencent APM] Created span: %s", span_data.name)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
|
||||
|
||||
# Metrics recording API
|
||||
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
|
||||
"""Record LLM operation duration histogram in seconds."""
|
||||
try:
|
||||
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
|
||||
return
|
||||
attrs: dict[str, str] = {}
|
||||
if attributes:
|
||||
for k, v in attributes.items():
|
||||
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
LLM_OPERATION_DURATION,
|
||||
latency_seconds,
|
||||
json.dumps(attrs, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
|
||||
|
||||
def record_token_usage(
|
||||
self,
|
||||
token_count: int,
|
||||
token_type: str,
|
||||
operation_name: str,
|
||||
request_model: str,
|
||||
response_model: str,
|
||||
server_address: str,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""Record token usage histogram.
|
||||
|
||||
Args:
|
||||
token_count: Number of tokens used
|
||||
token_type: "input" or "output"
|
||||
operation_name: Operation name (e.g., "chat")
|
||||
request_model: Model used in request
|
||||
response_model: Model used in response
|
||||
server_address: Server address
|
||||
provider: Model provider name
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, "hist_token_usage") or self.hist_token_usage is None:
|
||||
return
|
||||
|
||||
attributes = {
|
||||
"gen_ai.operation.name": operation_name,
|
||||
"gen_ai.request.model": request_model,
|
||||
"gen_ai.response.model": response_model,
|
||||
"gen_ai.system": provider,
|
||||
"gen_ai.token.type": token_type,
|
||||
"server.address": server_address,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
|
||||
GEN_AI_TOKEN_USAGE,
|
||||
token_count,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_token_usage.record(token_count, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)
|
||||
|
||||
def record_time_to_first_token(
|
||||
self, ttft_seconds: float, provider: str, model: str, operation_name: str = "chat"
|
||||
) -> None:
|
||||
"""Record time to first token histogram.
|
||||
|
||||
Args:
|
||||
ttft_seconds: Time to first token in seconds
|
||||
provider: Model provider name
|
||||
model: Model name
|
||||
operation_name: Operation name (default: "chat")
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, "hist_time_to_first_token") or self.hist_time_to_first_token is None:
|
||||
return
|
||||
|
||||
attributes = {
|
||||
"gen_ai.operation.name": operation_name,
|
||||
"gen_ai.system": provider,
|
||||
"gen_ai.request.model": model,
|
||||
"gen_ai.response.model": model,
|
||||
"stream": "true",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
|
||||
ttft_seconds,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_time_to_first_token.record(ttft_seconds, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)
|
||||
|
||||
def record_time_to_generate(
|
||||
self, ttg_seconds: float, provider: str, model: str, operation_name: str = "chat"
|
||||
) -> None:
|
||||
"""Record time to generate histogram.
|
||||
|
||||
Args:
|
||||
ttg_seconds: Time to generate in seconds
|
||||
provider: Model provider name
|
||||
model: Model name
|
||||
operation_name: Operation name (default: "chat")
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, "hist_time_to_generate") or self.hist_time_to_generate is None:
|
||||
return
|
||||
|
||||
attributes = {
|
||||
"gen_ai.operation.name": operation_name,
|
||||
"gen_ai.system": provider,
|
||||
"gen_ai.request.model": model,
|
||||
"gen_ai.response.model": model,
|
||||
"stream": "true",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_STREAMING_TIME_TO_GENERATE,
|
||||
ttg_seconds,
|
||||
json.dumps(attributes, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_time_to_generate.record(ttg_seconds, attributes) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)
|
||||
|
||||
def record_trace_duration(self, duration_seconds: float, attributes: dict[str, str] | None = None) -> None:
|
||||
"""Record end-to-end trace duration histogram in seconds.
|
||||
|
||||
Args:
|
||||
duration_seconds: Trace duration in seconds
|
||||
attributes: Optional attributes (e.g., conversation_mode, app_id)
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, "hist_trace_duration") or self.hist_trace_duration is None:
|
||||
return
|
||||
|
||||
attrs: dict[str, str] = {}
|
||||
if attributes:
|
||||
for k, v in attributes.items():
|
||||
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||
|
||||
logger.info(
|
||||
"[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
|
||||
GEN_AI_TRACE_DURATION,
|
||||
duration_seconds,
|
||||
json.dumps(attrs, ensure_ascii=False),
|
||||
)
|
||||
|
||||
self.hist_trace_duration.record(duration_seconds, attrs) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)
|
||||
|
||||
def _create_and_export_span(self, span_data: SpanData) -> None:
|
||||
"""Create span using OpenTelemetry Tracer API"""
|
||||
try:
|
||||
parent_context = None
|
||||
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
|
||||
parent_context = trace_api.set_span_in_context(
|
||||
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=span_data.name,
|
||||
context=parent_context,
|
||||
kind=SpanKind.INTERNAL,
|
||||
start_time=span_data.start_time,
|
||||
)
|
||||
self.span_contexts[span_data.span_id] = span.get_span_context()
|
||||
|
||||
if span_data.attributes:
|
||||
attributes: dict[str, AttributeValue] = {}
|
||||
for key, value in span_data.attributes.items():
|
||||
if isinstance(value, (int, float, bool)):
|
||||
attributes[key] = value
|
||||
else:
|
||||
attributes[key] = str(value)
|
||||
span.set_attributes(attributes)
|
||||
|
||||
if span_data.events:
|
||||
for event in span_data.events:
|
||||
span.add_event(event.name, event.attributes, event.timestamp)
|
||||
|
||||
if span_data.status:
|
||||
span.set_status(span_data.status)
|
||||
|
||||
# Manually end span; do not use context manager to avoid double-end warnings
|
||||
span.end(end_time=span_data.end_time)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
"""Check API connectivity using socket connection test for gRPC endpoints"""
|
||||
try:
|
||||
# Resolve gRPC target consistently with exporters
|
||||
_, _, host, port = self._resolve_grpc_target(self.endpoint)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(5)
|
||||
result = sock.connect_ex((host, port))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
|
||||
return True
|
||||
else:
|
||||
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
|
||||
if host in ["127.0.0.1", "localhost"]:
|
||||
logger.info("[Tencent APM] Development environment detected, allowing config save")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] API check failed")
|
||||
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
"""Get project console URL"""
|
||||
return "https://console.cloud.tencent.com/apm"
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the client and export remaining spans"""
|
||||
try:
|
||||
if self.span_processor:
|
||||
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
|
||||
_ = self.span_processor.force_flush()
|
||||
self.span_processor.shutdown()
|
||||
|
||||
if self.tracer_provider:
|
||||
self.tracer_provider.shutdown()
|
||||
|
||||
# Shutdown instance-level meter provider
|
||||
if self.meter_provider is not None:
|
||||
try:
|
||||
self.meter_provider.shutdown() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)
|
||||
|
||||
if self.metric_reader is not None:
|
||||
try:
|
||||
self.metric_reader.shutdown() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Error during client shutdown")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
|
||||
"""Normalize endpoint to gRPC target and security flag.
|
||||
|
||||
Returns:
|
||||
(grpc_endpoint, insecure, host, port)
|
||||
"""
|
||||
try:
|
||||
if endpoint.startswith(("http://", "https://")):
|
||||
parsed = urlparse(endpoint)
|
||||
host = parsed.hostname or "localhost"
|
||||
port = parsed.port or default_port
|
||||
insecure = parsed.scheme == "http"
|
||||
return f"{host}:{port}", insecure, host, port
|
||||
|
||||
host = endpoint
|
||||
port = default_port
|
||||
if ":" in endpoint:
|
||||
parts = endpoint.rsplit(":", 1)
|
||||
host = parts[0] or "localhost"
|
||||
try:
|
||||
port = int(parts[1])
|
||||
except Exception:
|
||||
port = default_port
|
||||
|
||||
insecure = ("localhost" in host) or ("127.0.0.1" in host)
|
||||
return f"{host}:{port}", insecure, host, port
|
||||
except Exception:
|
||||
host, port = "localhost", default_port
|
||||
return f"{host}:{port}", True, host, port
|
||||
@ -1 +0,0 @@
|
||||
# Tencent trace entities module
|
||||
@ -1,89 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
|
||||
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "gen_ai.entity.input"
|
||||
|
||||
OUTPUT_VALUE = "gen_ai.entity.output"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# GENERATION
|
||||
GEN_AI_MODEL_NAME = "gen_ai.response.model"
|
||||
|
||||
GEN_AI_PROVIDER = "gen_ai.provider.name"
|
||||
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Streaming Span Attributes
|
||||
GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming" # Same as OpenLLMetry semconv
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
# Instrumentation Library
|
||||
INSTRUMENTATION_NAME = "dify-sdk"
|
||||
INSTRUMENTATION_VERSION = "0.1.0"
|
||||
INSTRUMENTATION_LANGUAGE = "python"
|
||||
|
||||
|
||||
# Metrics
|
||||
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
|
||||
GEN_AI_TOKEN_USAGE = "gen_ai.client.token.usage"
|
||||
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN = "gen_ai.server.time_to_first_token"
|
||||
GEN_AI_STREAMING_TIME_TO_GENERATE = "gen_ai.streaming.time_to_generate"
|
||||
# The LLM trace duration which is exclusive to tencent apm
|
||||
GEN_AI_TRACE_DURATION = "gen_ai.trace.duration"
|
||||
|
||||
# Token Usage Attributes
|
||||
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
|
||||
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
|
||||
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
GEN_AI_TOKEN_TYPE = "gen_ai.token.type"
|
||||
SERVER_ADDRESS = "server.address"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
WORKFLOW = "WORKFLOW" # OpenLLMetry
|
||||
RETRIEVER = "RETRIEVER" # RAG
|
||||
GENERATION = "GENERATION" # Langfuse
|
||||
TOOL = "TOOL" # OpenLLMetry
|
||||
AGENT = "AGENT" # OpenLLMetry
|
||||
TASK = "TASK" # OpenLLMetry
|
||||
@ -1,21 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int = Field(..., description="The end time of the span in nanoseconds.")
|
||||
@ -1,380 +0,0 @@
|
||||
"""
|
||||
Tencent APM Span Builder - handles all span construction logic
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_IS_ENTRY,
|
||||
GEN_AI_IS_STREAMING_REQUEST,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROVIDER,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
from core.rag.models.document import Document
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentSpanBuilder:
|
||||
"""Builder class for constructing different types of spans"""
|
||||
|
||||
@staticmethod
|
||||
def _get_time_nanoseconds(time_value: datetime | None) -> int:
|
||||
"""Convert datetime to nanoseconds for span creation."""
|
||||
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_spans(
|
||||
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||
) -> list[SpanData]:
|
||||
"""Build workflow-related spans"""
|
||||
spans = []
|
||||
links = links or []
|
||||
|
||||
message_span_id = None
|
||||
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
|
||||
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
|
||||
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
|
||||
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
if message_span_id:
|
||||
message_span = TencentSpanBuilder._build_message_span(
|
||||
trace_info, trace_id, message_span_id, user_id, status, links
|
||||
)
|
||||
spans.append(message_span)
|
||||
|
||||
workflow_span = TencentSpanBuilder._build_workflow_span(
|
||||
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
|
||||
)
|
||||
spans.append(workflow_span)
|
||||
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _build_message_span(
|
||||
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
|
||||
) -> SpanData:
|
||||
"""Build message span for chatflow"""
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_IS_ENTRY: "true",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_workflow_span(
|
||||
trace_info: WorkflowTraceInfo,
|
||||
trace_id: int,
|
||||
workflow_span_id: int,
|
||||
message_span_id: int | None,
|
||||
user_id: str,
|
||||
status: Status,
|
||||
links: list,
|
||||
) -> SpanData:
|
||||
"""Build workflow span"""
|
||||
attributes = {
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
}
|
||||
|
||||
if message_span_id is None:
|
||||
attributes[GEN_AI_IS_ENTRY] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes=attributes,
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_llm_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build LLM span for workflow nodes."""
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
|
||||
attributes = {
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
}
|
||||
|
||||
if usage_data.get("time_to_first_token") is not None:
|
||||
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name="GENERATION",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes=attributes,
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_message_span(
|
||||
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||
) -> SpanData:
|
||||
"""Build message span."""
|
||||
links = links or []
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
attributes = {
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_IS_ENTRY: "true",
|
||||
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||
OUTPUT_VALUE: str(trace_info.outputs or ""),
|
||||
}
|
||||
|
||||
if trace_info.is_streaming_request:
|
||||
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
|
||||
name="message",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes=attributes,
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build tool span."""
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
|
||||
name=trace_info.tool_name,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: "",
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build dataset retrieval span."""
|
||||
status = Status(StatusCode.OK)
|
||||
if getattr(trace_info, "error", None):
|
||||
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
|
||||
|
||||
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
|
||||
name="retrieval",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
"""Get workflow node execution status."""
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
return Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return Status(StatusCode.UNSET)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_retrieval_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build knowledge retrieval span for workflow nodes."""
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_tool_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build tool span for workflow nodes."""
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_task_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build generic task span for workflow nodes."""
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_retrieval_documents(documents: list[Document]):
|
||||
"""Extract documents data for retrieval tracing."""
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
@ -1,522 +0,0 @@
|
||||
"""
|
||||
Tencent APM tracing implementation with separated concerns
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.client import TencentTraceClient
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentDataTrace(BaseTraceInstance):
|
||||
"""
|
||||
Tencent APM trace implementation with single responsibility principle.
|
||||
Acts as a coordinator that delegates specific tasks to specialized classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tencent_config: TencentConfig):
|
||||
super().__init__(tencent_config)
|
||||
self.trace_client = TencentTraceClient(
|
||||
service_name=tencent_config.service_name,
|
||||
endpoint=tencent_config.endpoint,
|
||||
token=tencent_config.token,
|
||||
metrics_export_interval_sec=5,
|
||||
)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
"""Main tracing entry point - coordinates different trace types."""
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self) -> bool:
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
return self.trace_client.get_project_url()
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
|
||||
"""Handle workflow tracing by coordinating data retrieval and span construction."""
|
||||
try:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
|
||||
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
user_id = self._get_user_id(trace_info)
|
||||
|
||||
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
|
||||
|
||||
for span in workflow_spans:
|
||||
self.trace_client.add_span(span)
|
||||
|
||||
self._process_workflow_nodes(trace_info, trace_id)
|
||||
|
||||
# Record trace duration for entry span
|
||||
self._record_workflow_trace_duration(trace_info)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process workflow trace")
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo) -> None:
|
||||
"""Handle message tracing."""
|
||||
try:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
|
||||
user_id = self._get_user_id(trace_info)
|
||||
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
|
||||
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
self._record_message_llm_metrics(trace_info)
|
||||
|
||||
# Record trace duration for entry span
|
||||
self._record_message_trace_duration(trace_info)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process message trace")
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
|
||||
"""Handle tool tracing."""
|
||||
try:
|
||||
parent_span_id = None
|
||||
trace_root_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
trace_root_id = trace_info.message_id
|
||||
|
||||
if parent_span_id and trace_root_id:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||
|
||||
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
|
||||
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process tool trace")
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
|
||||
"""Handle dataset retrieval tracing."""
|
||||
try:
|
||||
parent_span_id = None
|
||||
trace_root_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
trace_root_id = trace_info.message_id
|
||||
|
||||
if parent_span_id and trace_root_id:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||
|
||||
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
|
||||
|
||||
self.trace_client.add_span(retrieval_span)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
|
||||
"""Handle suggested question tracing"""
|
||||
try:
|
||||
logger.info("[Tencent APM] Processing suggested question trace")
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process suggested question trace")
|
||||
|
||||
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
|
||||
"""Process workflow node executions."""
|
||||
try:
|
||||
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
|
||||
node_executions = self._get_workflow_node_executions(trace_info)
|
||||
|
||||
for node_execution in node_executions:
|
||||
try:
|
||||
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
if node_span:
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
self._record_llm_metrics(node_execution)
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process workflow nodes")
|
||||
|
||||
def _build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
) -> SpanData | None:
|
||||
"""Build span for different node types"""
|
||||
try:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
return TencentSpanBuilder.build_workflow_llm_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
return TencentSpanBuilder.build_workflow_retrieval_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
return TencentSpanBuilder.build_workflow_tool_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
else:
|
||||
# Handle all other node types as generic tasks
|
||||
return TencentSpanBuilder.build_workflow_task_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[Tencent APM] Error building span for node %s: %s",
|
||||
node_execution.id,
|
||||
node_execution.node_type,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
|
||||
"""Retrieve workflow node executions from database."""
|
||||
try:
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app_stmt = select(App).where(App.id == app_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator")
|
||||
|
||||
account_stmt = select(Account).where(Account.id == app.created_by)
|
||||
service_account = session.scalar(account_stmt)
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||
.limit(1)
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_maker,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id)
|
||||
return list(executions)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to get workflow node executions")
|
||||
return []
|
||||
|
||||
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
|
||||
"""Get user ID from trace info."""
|
||||
try:
|
||||
tenant_id = None
|
||||
user_id = None
|
||||
|
||||
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
|
||||
tenant_id = trace_info.tenant_id
|
||||
|
||||
if hasattr(trace_info, "metadata") and trace_info.metadata:
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
|
||||
if user_id and tenant_id:
|
||||
stmt = (
|
||||
select(Account.name)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
|
||||
)
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
with session_maker() as session:
|
||||
account_name = session.scalar(stmt)
|
||||
return account_name or str(user_id)
|
||||
elif user_id:
|
||||
return str(user_id)
|
||||
|
||||
return "anonymous"
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to get user ID")
|
||||
return "unknown"
|
||||
|
||||
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
|
||||
"""Record LLM performance metrics"""
|
||||
try:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
|
||||
model_provider = process_data.get("model_provider", "unknown")
|
||||
model_name = process_data.get("model_name", "unknown")
|
||||
model_mode = process_data.get("model_mode", "chat")
|
||||
|
||||
# Record LLM duration
|
||||
if hasattr(self.trace_client, "record_llm_duration"):
|
||||
latency_s = float(usage.get("latency", 0.0))
|
||||
|
||||
if latency_s > 0:
|
||||
# Determine if streaming from usage metrics
|
||||
is_streaming = usage.get("time_to_first_token") is not None
|
||||
|
||||
attributes = {
|
||||
"gen_ai.system": model_provider,
|
||||
"gen_ai.response.model": model_name,
|
||||
"gen_ai.operation.name": model_mode,
|
||||
"stream": "true" if is_streaming else "false",
|
||||
}
|
||||
self.trace_client.record_llm_duration(latency_s, attributes)
|
||||
|
||||
# Record streaming metrics from usage
|
||||
time_to_first_token = usage.get("time_to_first_token")
|
||||
if time_to_first_token is not None and hasattr(self.trace_client, "record_time_to_first_token"):
|
||||
ttft_seconds = float(time_to_first_token)
|
||||
if ttft_seconds > 0:
|
||||
self.trace_client.record_time_to_first_token(
|
||||
ttft_seconds=ttft_seconds, provider=model_provider, model=model_name, operation_name=model_mode
|
||||
)
|
||||
|
||||
time_to_generate = usage.get("time_to_generate")
|
||||
if time_to_generate is not None and hasattr(self.trace_client, "record_time_to_generate"):
|
||||
ttg_seconds = float(time_to_generate)
|
||||
if ttg_seconds > 0:
|
||||
self.trace_client.record_time_to_generate(
|
||||
ttg_seconds=ttg_seconds, provider=model_provider, model=model_name, operation_name=model_mode
|
||||
)
|
||||
|
||||
# Record token usage
|
||||
if hasattr(self.trace_client, "record_token_usage"):
|
||||
# Extract token counts
|
||||
input_tokens = int(usage.get("prompt_tokens", 0))
|
||||
output_tokens = int(usage.get("completion_tokens", 0))
|
||||
|
||||
if input_tokens > 0 or output_tokens > 0:
|
||||
server_address = f"{model_provider}"
|
||||
|
||||
# Record input tokens
|
||||
if input_tokens > 0:
|
||||
self.trace_client.record_token_usage(
|
||||
token_count=input_tokens,
|
||||
token_type="input",
|
||||
operation_name=model_mode,
|
||||
request_model=model_name,
|
||||
response_model=model_name,
|
||||
server_address=server_address,
|
||||
provider=model_provider,
|
||||
)
|
||||
|
||||
# Record output tokens
|
||||
if output_tokens > 0:
|
||||
self.trace_client.record_token_usage(
|
||||
token_count=output_tokens,
|
||||
token_type="output",
|
||||
operation_name=model_mode,
|
||||
request_model=model_name,
|
||||
response_model=model_name,
|
||||
server_address=server_address,
|
||||
provider=model_provider,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record LLM metrics")
|
||||
|
||||
def _record_message_llm_metrics(self, trace_info: MessageTraceInfo) -> None:
|
||||
"""Record LLM metrics for message traces"""
|
||||
try:
|
||||
trace_metadata = trace_info.metadata or {}
|
||||
message_data = trace_info.message_data or {}
|
||||
provider_latency = 0.0
|
||||
if isinstance(message_data, dict):
|
||||
provider_latency = float(message_data.get("provider_response_latency", 0.0) or 0.0)
|
||||
else:
|
||||
provider_latency = float(getattr(message_data, "provider_response_latency", 0.0) or 0.0)
|
||||
|
||||
model_provider = trace_metadata.get("ls_provider") or (
|
||||
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
model_name = trace_metadata.get("ls_model_name") or (
|
||||
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
|
||||
)
|
||||
|
||||
# Record LLM duration
|
||||
if provider_latency > 0 and hasattr(self.trace_client, "record_llm_duration"):
|
||||
is_streaming = trace_info.is_streaming_request
|
||||
|
||||
duration_attributes = {
|
||||
"gen_ai.system": model_provider,
|
||||
"gen_ai.response.model": model_name,
|
||||
"gen_ai.operation.name": "chat", # Message traces are always chat
|
||||
"stream": "true" if is_streaming else "false",
|
||||
}
|
||||
self.trace_client.record_llm_duration(provider_latency, duration_attributes)
|
||||
|
||||
# Record streaming metrics for message traces
|
||||
if trace_info.is_streaming_request:
|
||||
# Record time to first token
|
||||
if trace_info.gen_ai_server_time_to_first_token is not None and hasattr(
|
||||
self.trace_client, "record_time_to_first_token"
|
||||
):
|
||||
ttft_seconds = float(trace_info.gen_ai_server_time_to_first_token)
|
||||
if ttft_seconds > 0:
|
||||
self.trace_client.record_time_to_first_token(
|
||||
ttft_seconds=ttft_seconds, provider=str(model_provider or ""), model=str(model_name or "")
|
||||
)
|
||||
|
||||
# Record time to generate
|
||||
if trace_info.llm_streaming_time_to_generate is not None and hasattr(
|
||||
self.trace_client, "record_time_to_generate"
|
||||
):
|
||||
ttg_seconds = float(trace_info.llm_streaming_time_to_generate)
|
||||
if ttg_seconds > 0:
|
||||
self.trace_client.record_time_to_generate(
|
||||
ttg_seconds=ttg_seconds, provider=str(model_provider or ""), model=str(model_name or "")
|
||||
)
|
||||
|
||||
# Record token usage
|
||||
if hasattr(self.trace_client, "record_token_usage"):
|
||||
input_tokens = int(trace_info.message_tokens or 0)
|
||||
output_tokens = int(trace_info.answer_tokens or 0)
|
||||
|
||||
if input_tokens > 0:
|
||||
self.trace_client.record_token_usage(
|
||||
token_count=input_tokens,
|
||||
token_type="input",
|
||||
operation_name="chat",
|
||||
request_model=str(model_name or ""),
|
||||
response_model=str(model_name or ""),
|
||||
server_address=str(model_provider or ""),
|
||||
provider=str(model_provider or ""),
|
||||
)
|
||||
|
||||
if output_tokens > 0:
|
||||
self.trace_client.record_token_usage(
|
||||
token_count=output_tokens,
|
||||
token_type="output",
|
||||
operation_name="chat",
|
||||
request_model=str(model_name or ""),
|
||||
response_model=str(model_name or ""),
|
||||
server_address=str(model_provider or ""),
|
||||
provider=str(model_provider or ""),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record message LLM metrics")
|
||||
|
||||
def _record_workflow_trace_duration(self, trace_info: WorkflowTraceInfo) -> None:
|
||||
"""Record end-to-end workflow trace duration."""
|
||||
try:
|
||||
if not hasattr(self.trace_client, "record_trace_duration"):
|
||||
return
|
||||
|
||||
# Calculate duration from start_time and end_time to match span duration
|
||||
if trace_info.start_time and trace_info.end_time:
|
||||
duration_s = (trace_info.end_time - trace_info.start_time).total_seconds()
|
||||
else:
|
||||
# Fallback to workflow_run_elapsed_time if timestamps not available
|
||||
duration_s = float(trace_info.workflow_run_elapsed_time)
|
||||
|
||||
if duration_s > 0:
|
||||
attributes = {
|
||||
"conversation_mode": "workflow",
|
||||
"workflow_status": trace_info.workflow_run_status,
|
||||
}
|
||||
|
||||
# Add conversation_id if available
|
||||
if trace_info.conversation_id:
|
||||
attributes["has_conversation"] = "true"
|
||||
else:
|
||||
attributes["has_conversation"] = "false"
|
||||
|
||||
self.trace_client.record_trace_duration(duration_s, attributes)
|
||||
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record workflow trace duration")
|
||||
|
||||
def _record_message_trace_duration(self, trace_info: MessageTraceInfo) -> None:
|
||||
"""Record end-to-end message trace duration."""
|
||||
try:
|
||||
if not hasattr(self.trace_client, "record_trace_duration"):
|
||||
return
|
||||
|
||||
# Calculate duration from start_time and end_time
|
||||
if trace_info.start_time and trace_info.end_time:
|
||||
duration = (trace_info.end_time - trace_info.start_time).total_seconds()
|
||||
|
||||
if duration > 0:
|
||||
attributes = {
|
||||
"conversation_mode": trace_info.conversation_mode,
|
||||
}
|
||||
|
||||
# Add streaming flag if available
|
||||
if hasattr(trace_info, "is_streaming_request"):
|
||||
attributes["stream"] = "true" if trace_info.is_streaming_request else "false"
|
||||
|
||||
self.trace_client.record_trace_duration(duration, attributes)
|
||||
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record message trace duration")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure proper cleanup on garbage collection."""
|
||||
try:
|
||||
if hasattr(self, "trace_client"):
|
||||
self.trace_client.shutdown()
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
@ -1,64 +0,0 @@
|
||||
"""
|
||||
Utility functions for Tencent APM tracing
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import random
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
|
||||
class TencentTraceUtils:
|
||||
"""Utility class for common tracing operations."""
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
@staticmethod
|
||||
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
return uuid_obj.int
|
||||
|
||||
@staticmethod
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
|
||||
@staticmethod
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
|
||||
if start_time is None:
|
||||
start_time = datetime.now()
|
||||
timestamp_in_seconds = start_time.timestamp()
|
||||
return int(timestamp_in_seconds * 1e9)
|
||||
|
||||
@staticmethod
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
try:
|
||||
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int
|
||||
except (ValueError, TypeError):
|
||||
trace_id = uuid.uuid4().int
|
||||
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=TencentTraceUtils.INVALID_SPAN_ID,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
return Link(span_context)
|
||||
@ -1,99 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.ops.utils import replace_text_with_content
|
||||
|
||||
|
||||
class WeaveTokenUsage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class WeaveMultiModel(BaseModel):
|
||||
file_list: list[str] | None = Field(None, description="List of files")
|
||||
|
||||
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace")
|
||||
outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace")
|
||||
attributes: Union[str, dict[str, Any], list, None] | None = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
exception: str | None = Field(None, description="Exception message of the trace")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
if v == {} or v is None:
|
||||
return v
|
||||
usage_metadata = {
|
||||
"input_tokens": values.get("input_tokens", 0),
|
||||
"output_tokens": values.get("output_tokens", 0),
|
||||
"total_tokens": values.get("total_tokens", 0),
|
||||
}
|
||||
file_list = values.get("file_list", [])
|
||||
if isinstance(v, str):
|
||||
if field_name == "inputs":
|
||||
return {
|
||||
"messages": {
|
||||
"role": "user",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
data = {}
|
||||
if len(v) > 0 and isinstance(v[0], dict):
|
||||
# rename text to content
|
||||
v = replace_text_with_content(data=v)
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
}
|
||||
elif field_name == "outputs":
|
||||
data = {
|
||||
"choices": {
|
||||
"role": "ai",
|
||||
"content": v,
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
return data
|
||||
else:
|
||||
return {
|
||||
"choices": {
|
||||
"role": "ai" if field_name == "outputs" else "user",
|
||||
"content": str(v),
|
||||
"usage_metadata": usage_metadata,
|
||||
"file_list": file_list,
|
||||
},
|
||||
}
|
||||
if isinstance(v, dict):
|
||||
v["usage_metadata"] = usage_metadata
|
||||
v["file_list"] = file_list
|
||||
return v
|
||||
return v
|
||||
@ -1,521 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from weave.trace_server.trace_server_interface import (
|
||||
CallEndReq,
|
||||
CallStartReq,
|
||||
EndedCallSchemaForInsert,
|
||||
StartedCallSchemaForInsert,
|
||||
SummaryInsertMap,
|
||||
TraceStatus,
|
||||
)
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeaveDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
weave_config: WeaveConfig,
|
||||
):
|
||||
super().__init__(weave_config)
|
||||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
self.host = weave_config.host
|
||||
|
||||
# Login with API key first, including host if provided
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
logger.error("Failed to login to Weights & Biases with the provided API key")
|
||||
raise ValueError("Weave login failed")
|
||||
|
||||
# Then initialize weave client
|
||||
self.weave_client = weave.init(
|
||||
project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
|
||||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.calls: dict[str, Any] = {}
|
||||
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
|
||||
|
||||
def get_project_url(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
|
||||
project_url = f"https://wandb.ai/{project_identifier}"
|
||||
return project_url
|
||||
except Exception as e:
|
||||
logger.debug("Weave get run url failed: %s", str(e))
|
||||
raise ValueError(f"Weave get run url failed: {str(e)}")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.debug("Trace info: %s", trace_info)
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
|
||||
if trace_info.message_id:
|
||||
message_attributes = trace_info.metadata
|
||||
message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
message_attributes["message_id"] = trace_info.message_id
|
||||
message_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
message_attributes["trace_id"] = trace_id
|
||||
message_attributes["start_time"] = trace_info.start_time
|
||||
message_attributes["end_time"] = trace_info.end_time
|
||||
message_attributes["tags"] = ["message", "workflow"]
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_info.message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
total_tokens=trace_info.total_tokens,
|
||||
attributes=message_attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(message_run)
|
||||
|
||||
workflow_attributes = trace_info.metadata
|
||||
workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
workflow_attributes["trace_id"] = trace_id
|
||||
workflow_attributes["start_time"] = trace_info.start_time
|
||||
workflow_attributes["end_time"] = trace_info.end_time
|
||||
workflow_attributes["tags"] = ["dify_workflow"]
|
||||
|
||||
workflow_run = WeaveTraceModel(
|
||||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
attributes=workflow_attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
|
||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
|
||||
workflow_execution_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
# rearrange workflow_node_executions by starting time
|
||||
workflow_node_executions = sorted(workflow_node_executions, key=lambda x: x.created_at)
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
"node_execution_id": node_execution_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_name,
|
||||
"node_type": node_type,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
"ls_provider": process_data.get("model_provider", ""),
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
attributes["tags"] = ["node_execution"]
|
||||
attributes["start_time"] = created_at
|
||||
attributes["end_time"] = finished_at
|
||||
attributes["elapsed_time"] = elapsed_time
|
||||
attributes["workflow_run_id"] = trace_info.workflow_run_id
|
||||
attributes["trace_id"] = trace_id
|
||||
node_run = WeaveTraceModel(
|
||||
total_tokens=node_total_tokens,
|
||||
op=node_type,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
file_list=trace_info.file_list,
|
||||
attributes=attributes,
|
||||
id=node_execution_id,
|
||||
exception=None,
|
||||
)
|
||||
|
||||
self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
|
||||
self.finish_call(node_run)
|
||||
|
||||
self.finish_call(workflow_run)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
# get message file data
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: MessageFile | None = trace_info.message_file_data
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
attributes = trace_info.metadata
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = message_data.id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
attributes["user_id"] = user_id
|
||||
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
|
||||
if end_user_data is not None:
|
||||
end_user_id = end_user_data.session_id
|
||||
attributes["end_user_id"] = end_user_id
|
||||
|
||||
attributes["message_id"] = message_id
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
|
||||
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
exception=trace_info.error,
|
||||
file_list=file_list,
|
||||
attributes=attributes,
|
||||
)
|
||||
self.start_call(message_run)
|
||||
|
||||
# create llm run parented to message run
|
||||
llm_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
op="llm",
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
file_list=[],
|
||||
exception=None,
|
||||
)
|
||||
self.start_call(
|
||||
llm_run,
|
||||
parent_run_id=trace_id,
|
||||
)
|
||||
self.finish_call(llm_run)
|
||||
self.finish_call(message_run)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["moderation"]
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
||||
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(moderation_run, parent_run_id=trace_id)
|
||||
self.finish_call(moderation_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["suggested_question"]
|
||||
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(suggested_question_run, parent_run_id=trace_id)
|
||||
self.finish_call(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
attributes = trace_info.metadata
|
||||
attributes["message_id"] = trace_info.message_id
|
||||
attributes["tags"] = ["dataset_retrieval"]
|
||||
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
|
||||
self.finish_call(dataset_retrieval_run)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["tool", trace_info.tool_name]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||
message_id = message_id or None
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
tool_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=trace_info.tool_name,
|
||||
inputs=trace_info.tool_inputs,
|
||||
outputs=trace_info.tool_outputs,
|
||||
file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
self.start_call(tool_run, parent_run_id=trace_id)
|
||||
self.finish_call(tool_run)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
attributes = trace_info.metadata
|
||||
attributes["tags"] = ["generate_name"]
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
name_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(name_run)
|
||||
self.finish_call(name_run)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
logger.info("Weave login successful")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug("Weave API check failed: %s", str(e))
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def _normalize_time(self, dt: datetime | None) -> datetime:
|
||||
if dt is None:
|
||||
return datetime.now(UTC)
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
elif not isinstance(inputs, dict):
|
||||
inputs = {"inputs": str(inputs)}
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
|
||||
if trace_id is None:
|
||||
trace_id = run_data.id
|
||||
|
||||
call_start_req = CallStartReq(
|
||||
start=StartedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
op_name=str(run_data.op),
|
||||
trace_id=trace_id,
|
||||
parent_id=parent_run_id,
|
||||
started_at=started_at,
|
||||
attributes=attributes,
|
||||
inputs=inputs,
|
||||
wb_user_id=None,
|
||||
)
|
||||
)
|
||||
self.weave_client.server.call_start(call_start_req)
|
||||
self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
|
||||
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call_meta = self.calls.get(run_data.id)
|
||||
if not call_meta:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
|
||||
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
|
||||
if elapsed_ms < 0:
|
||||
elapsed_ms = 0
|
||||
|
||||
status_counts = {
|
||||
TraceStatus.SUCCESS: 0,
|
||||
TraceStatus.ERROR: 0,
|
||||
}
|
||||
if run_data.exception:
|
||||
status_counts[TraceStatus.ERROR] = 1
|
||||
else:
|
||||
status_counts[TraceStatus.SUCCESS] = 1
|
||||
|
||||
summary: dict[str, Any] = {
|
||||
"status_counts": status_counts,
|
||||
"weave": {"latency_ms": elapsed_ms},
|
||||
}
|
||||
|
||||
exception_str = str(run_data.exception) if run_data.exception else None
|
||||
|
||||
call_end_req = CallEndReq(
|
||||
end=EndedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
ended_at=ended_at,
|
||||
exception=exception_str,
|
||||
output=run_data.outputs,
|
||||
summary=cast(SummaryInsertMap, summary),
|
||||
)
|
||||
)
|
||||
self.weave_client.server.call_end(call_end_req)
|
||||
Reference in New Issue
Block a user