feat: extract mcp tool usage (#31802)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-02-09 09:52:14 +08:00
committed by GitHub
parent aa800d838d
commit 483db22b97
2 changed files with 332 additions and 2 deletions

View File

@ -3,8 +3,8 @@ from __future__ import annotations
import base64
import json
import logging
from collections.abc import Generator
from typing import Any
from collections.abc import Generator, Mapping
from typing import Any, cast
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
@ -17,6 +17,7 @@ from core.mcp.types import (
TextContent,
TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -46,6 +47,7 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self._latest_usage = LLMUsage.empty_usage()
def tool_provider_type(self) -> ToolProviderType:
    """
    Report which provider family this tool belongs to.

    Returns:
        Always ToolProviderType.MCP.
    """
    return ToolProviderType.MCP
@ -59,6 +61,10 @@ class MCPTool(Tool):
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
@ -120,6 +126,99 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage
@classmethod
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
    """
    Build an LLMUsage from the _meta field of an MCP tool result.

    The MCP protocol's _meta field (exposed as 'meta' in Python) may carry
    usage information such as token counts and costs. When no usage data
    can be located, an empty usage object is returned instead.

    Args:
        result: The CallToolResult returned by the MCP tool invocation

    Returns:
        LLMUsage populated from the meta payload, or LLMUsage.empty_usage()
        when meta is missing/empty or contains no usage dictionary
    """
    meta = result.meta
    # Missing or empty meta: nothing to search.
    if not meta:
        return LLMUsage.empty_usage()

    located = cls._extract_usage_dict(meta)
    if located is None:
        return LLMUsage.empty_usage()

    # Copy into a plain dict; the double cast only satisfies the type
    # checker (Mapping -> object -> LLMUsageMetadata) and has no runtime effect.
    usage_metadata = cast(LLMUsageMetadata, cast(object, dict(located)))
    return LLMUsage.from_metadata(usage_metadata)
@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
"""
Recursively search for usage dictionary in the payload.
The MCP protocol's _meta field can contain usage data in various formats:
- Direct usage field: {"usage": {...}}
- Nested in metadata: {"metadata": {"usage": {...}}}
- Or nested within other fields
Args:
payload: The payload to search for usage data
Returns:
The usage dictionary if found, None otherwise
"""
# Check for direct usage field
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for metadata nested usage
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate
# Check for common token counting fields directly in payload
# Some MCP servers may include token counts directly
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
usage_dict: dict[str, Any] = {}
for key in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"prompt_unit_price",
"completion_unit_price",
"total_price",
"currency",
"prompt_price_unit",
"completion_price_unit",
"prompt_price",
"completion_price",
"latency",
"time_to_first_token",
"time_to_generate",
):
if key in payload:
usage_dict[key] = payload[key]
if usage_dict:
return usage_dict
# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,