Mirror of https://github.com/langgenius/dify.git, synced 2026-04-19 18:27:27 +08:00.
feat: support ttft report to langfuse (#33344)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
@staticmethod
|
||||
def _get_completion_start_time(
|
||||
start_time: datetime | None, time_to_first_token: float | int | None
|
||||
) -> datetime | None:
|
||||
"""Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
|
||||
if start_time is None or time_to_first_token is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
ttft_seconds = float(time_to_first_token)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if ttft_seconds < 0:
|
||||
return None
|
||||
|
||||
return start_time + timedelta(seconds=ttft_seconds)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
completion_start_time = None
|
||||
try:
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
usage_data = process_data.get("usage")
|
||||
if not isinstance(usage_data, dict):
|
||||
usage_data = outputs.get("usage")
|
||||
if not isinstance(usage_data, dict):
|
||||
usage_data = {}
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
completion_start_time = self._get_completion_start_time(
|
||||
created_at, usage_data.get("time_to_first_token")
|
||||
)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
trace_id=trace_id,
|
||||
model=process_data.get("model_name"),
|
||||
start_time=created_at,
|
||||
completion_start_time=completion_start_time,
|
||||
end_time=finished_at,
|
||||
input=inputs,
|
||||
output=outputs,
|
||||
@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
unit=UnitEnum.TOKENS,
|
||||
totalCost=message_data.total_price,
|
||||
)
|
||||
completion_start_time = self._get_completion_start_time(
|
||||
trace_info.start_time,
|
||||
trace_info.gen_ai_server_time_to_first_token,
|
||||
)
|
||||
|
||||
langfuse_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
trace_id=trace_id,
|
||||
start_time=trace_info.start_time,
|
||||
completion_start_time=completion_start_time,
|
||||
end_time=trace_info.end_time,
|
||||
model=message_data.model_id,
|
||||
input=trace_info.inputs,
|
||||
|
||||
137 lines — api/tests/unit_tests/core/ops/test_langfuse_trace.py (new file)
@ -0,0 +1,137 @@
|
||||
"""Tests for Langfuse TTFT reporting support."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
|
||||
def _create_trace_instance() -> LangFuseDataTrace:
    """Build a LangFuseDataTrace wired to a dummy config, with the Langfuse client patched out."""
    config = LangfuseConfig(
        public_key="public-key",
        secret_key="secret-key",
        host="https://cloud.langfuse.com",
    )
    # Patch the SDK client so constructing the trace instance performs no network I/O.
    with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True):
        return LangFuseDataTrace(config)
|
||||
|
||||
|
||||
class TestLangFuseDataTraceCompletionStartTime:
    """Verify that TTFT values are converted into Langfuse ``completion_start_time``."""

    def test_message_trace_reports_completion_start_time(self):
        # A message trace carrying gen_ai_server_time_to_first_token should yield
        # a completion_start_time offset from the trace's start_time.
        tracer = _create_trace_instance()
        started_at = datetime(2026, 3, 11, 13, 0, 0)
        trace_info = MessageTraceInfo(
            trace_id="trace-123",
            message_id="message-123",
            message_data=SimpleNamespace(
                id="message-123",
                from_account_id="account-1",
                from_end_user_id=None,
                conversation_id="conversation-1",
                model_id="gpt-4o-mini",
                answer="hi there",
                status="normal",
                error="",
                total_price=0.12,
                provider_response_latency=3.5,
            ),
            conversation_model="chat",
            message_tokens=10,
            answer_tokens=20,
            total_tokens=30,
            error="",
            inputs="hello",
            outputs="hi there",
            file_list=[],
            start_time=started_at,
            end_time=started_at + timedelta(seconds=3.5),
            metadata={},
            message_file_data=None,
            conversation_mode="chat",
            gen_ai_server_time_to_first_token=1.2,
            llm_streaming_time_to_generate=2.3,
            is_streaming_request=True,
        )

        with patch.object(tracer, "add_trace"), patch.object(tracer, "add_generation") as add_generation:
            tracer.message_trace(trace_info)

        generation = add_generation.call_args.args[0]
        assert generation.completion_start_time == started_at + timedelta(seconds=1.2)

    def test_workflow_trace_reports_completion_start_time_from_llm_usage(self):
        # For workflow traces the TTFT lives inside the LLM node's usage dict;
        # the generation built for that node must carry the derived absolute time.
        tracer = _create_trace_instance()
        started_at = datetime(2026, 3, 11, 13, 0, 0)
        node_execution = SimpleNamespace(
            id="node-exec-1",
            title="Chat LLM",
            node_type=BuiltinNodeTypes.LLM,
            status="succeeded",
            process_data={
                "model_mode": "chat",
                "model_name": "gpt-4o-mini",
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "time_to_first_token": 1.2,
                },
            },
            inputs={"question": "hello"},
            outputs={"text": "hi there"},
            created_at=started_at,
            elapsed_time=3.5,
            metadata={},
        )
        trace_info = WorkflowTraceInfo(
            trace_id="trace-123",
            workflow_data={},
            conversation_id=None,
            workflow_app_log_id=None,
            workflow_id="workflow-1",
            tenant_id="tenant-1",
            workflow_run_id="workflow-run-1",
            workflow_run_elapsed_time=3.5,
            workflow_run_status="succeeded",
            workflow_run_inputs={"question": "hello"},
            workflow_run_outputs={"answer": "hi there"},
            workflow_run_version="1",
            error="",
            total_tokens=30,
            file_list=[],
            query="hello",
            metadata={"app_id": "app-1", "user_id": "user-1"},
            start_time=started_at,
            end_time=started_at + timedelta(seconds=3.5),
        )
        repository = MagicMock()
        repository.get_by_workflow_execution.return_value = [node_execution]

        # Stub out every external collaborator so only the TTFT plumbing is exercised.
        with (
            patch.object(tracer, "add_trace"),
            patch.object(tracer, "add_span"),
            patch.object(tracer, "add_generation") as add_generation,
            patch.object(tracer, "get_service_account_with_tenant", return_value=MagicMock()),
            patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()),
            patch(
                "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
                return_value=repository,
            ),
        ):
            tracer.workflow_trace(trace_info)

        generation = add_generation.call_args.kwargs["langfuse_generation_data"]
        assert generation.completion_start_time == started_at + timedelta(seconds=1.2)

    def test_ignores_invalid_ttft_values(self):
        # Missing, negative, or non-numeric TTFT values must all be discarded.
        tracer = _create_trace_instance()
        started_at = datetime(2026, 3, 11, 13, 0, 0)

        assert tracer._get_completion_start_time(started_at, None) is None
        assert tracer._get_completion_start_time(started_at, -1) is None
        assert tracer._get_completion_start_time(started_at, "invalid") is None
|
||||
@ -1702,7 +1702,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match=""):
|
||||
with pytest.raises(ValueError):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@ -131,9 +131,12 @@ class TestMessageServicePaginationByFirstId:
|
||||
assert result.has_more is False
|
||||
|
||||
# Test 03: Basic pagination without first_id (desc order)
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_desc(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_without_first_id_desc(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory
|
||||
):
|
||||
"""Test basic pagination without first_id in descending order."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -171,9 +174,12 @@ class TestMessageServicePaginationByFirstId:
|
||||
assert result.data[0].id == "msg-000"
|
||||
|
||||
# Test 04: Basic pagination without first_id (asc order)
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_without_first_id_asc(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_without_first_id_asc(
|
||||
self, mock_conversation_service, mock_db, mock_create_repo, factory
|
||||
):
|
||||
"""Test basic pagination without first_id in ascending order."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -211,9 +217,10 @@ class TestMessageServicePaginationByFirstId:
|
||||
assert result.data[4].id == "msg-000"
|
||||
|
||||
# Test 05: Pagination with first_id
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_conversation_service, mock_db, mock_create_repo, factory):
|
||||
"""Test pagination with first_id to get messages before a specific message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -278,9 +285,10 @@ class TestMessageServicePaginationByFirstId:
|
||||
)
|
||||
|
||||
# Test 07: Has_more flag when results exceed limit
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db")
|
||||
@patch("services.message_service.ConversationService")
|
||||
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, factory):
|
||||
def test_pagination_by_first_id_has_more_true(self, mock_conversation_service, mock_db, mock_create_repo, factory):
|
||||
"""Test has_more flag is True when results exceed limit."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
Reference in New Issue
Block a user