[frontend][gptoss] Add per-turn stats into Harmony Context (#25061)

Signed-off-by: lacora <hyelacora@gmail.com>
Co-authored-by: Ye Hu <yehu@fb.com>
This commit is contained in:
Ye Hu
2025-10-14 16:48:13 -07:00
committed by GitHub
parent 7e0ef4084a
commit 0512c04aee
4 changed files with 188 additions and 62 deletions

View File

@@ -6,7 +6,11 @@ from unittest.mock import MagicMock, patch
import pytest
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
from vllm.entrypoints.context import (
HarmonyContext,
StreamingHarmonyContext,
TurnMetrics,
)
from vllm.outputs import CompletionOutput, RequestOutput
@@ -101,8 +105,12 @@ def test_single_turn_token_counting():
# Verify internal state tracking
assert not context.is_first_turn
assert context.previous_turn.input_tokens == 5
assert context.previous_turn.output_tokens == 3
assert len(context.all_turn_metrics) == 1
previous_turn = context.all_turn_metrics[0]
assert previous_turn.input_tokens == 5
assert previous_turn.output_tokens == 3
assert previous_turn.cached_input_tokens == 2
assert previous_turn.tool_output_tokens == 0
@pytest.mark.asyncio
@@ -156,6 +164,15 @@ async def test_multi_turn_token_counting():
assert context.num_tool_output_tokens == expected_tool_output
assert context.num_cached_tokens == 5 + 15
# Validate all turn metrics
assert len(context.all_turn_metrics) == 3
for i, turn in enumerate(context.all_turn_metrics):
assert turn.input_tokens == prompt_token_counts[i]
assert turn.output_tokens == output_token_counts[i]
assert turn.cached_input_tokens == cached_token_counts[i]
assert context.all_turn_metrics[1].tool_output_tokens == 7
assert context.all_turn_metrics[2].tool_output_tokens == 1
def test_empty_output_tokens():
"""Test behavior when RequestOutput has empty output tokens."""
@@ -314,6 +331,10 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
# Create a streaming context
context = StreamingHarmonyContext(messages=[], available_tools=["browser"])
num_prompt_tokens = [3, 8, 13]
num_output_tokens = [3, 3, 2]
num_cached_tokens = [0, 3, 8]
# Simulate three turns of conversation:
# Turn 1: stream tokens one by one, then finish the message
# Turn 2: new prompt, stream more tokens with a reasoning segment
@@ -325,7 +346,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
create_mock_request_output(
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
output_token_ids=[101], # Single token
num_cached_tokens=0,
num_cached_tokens=num_cached_tokens[0],
finished=False, # Not end of message yet
)
)
@@ -370,7 +391,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
5,
], # 8 tokens (includes previous)
output_token_ids=[201],
num_cached_tokens=3, # Some tokens cached
num_cached_tokens=num_cached_tokens[1], # Some tokens cached
finished=False,
)
)
@@ -422,7 +443,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
7,
], # 13 tokens
output_token_ids=[301],
num_cached_tokens=8, # More cached tokens
num_cached_tokens=num_cached_tokens[2], # More cached tokens
finished=False,
)
)
@@ -435,10 +456,12 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
)
# Final token counts check
assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts
assert context.num_output_tokens == 3 + 3 + 2 # All outputs
assert context.num_prompt_tokens == sum(num_prompt_tokens) # All prompts
assert context.num_output_tokens == sum(num_output_tokens) # All outputs
assert context.num_reasoning_tokens == 3 # Unchanged from second turn
assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens
assert context.num_cached_tokens == sum(
num_cached_tokens
) # Accumulated cached tokens
# Additional tool tokens from third turn
# Formula: this turn prompt - last turn prompt - last turn output
@@ -447,6 +470,15 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens
)
# Validate all turn metrics
assert len(context.all_turn_metrics) == 3
for i, turn in enumerate(context.all_turn_metrics):
assert turn.input_tokens == num_prompt_tokens[i]
assert turn.output_tokens == num_output_tokens[i]
assert turn.cached_input_tokens == num_cached_tokens[i]
assert context.all_turn_metrics[1].tool_output_tokens == 2
assert context.all_turn_metrics[2].tool_output_tokens == 2
@pytest.mark.asyncio
async def test_streaming_message_synchronization(mock_parser):
@@ -522,3 +554,46 @@ async def test_streaming_message_synchronization(mock_parser):
assert len(context._messages) == 3
assert context.num_init_messages == 1
assert context._messages[2].content[0].text == "Response 4"
def test_turn_metrics_copy_and_reset():
    """TurnMetrics.copy() yields an independent object; reset() zeroes all fields."""
    source = TurnMetrics(
        input_tokens=10,
        output_tokens=20,
        cached_input_tokens=5,
        tool_output_tokens=3,
    )

    # copy() must reproduce every counter exactly...
    duplicate = source.copy()
    assert duplicate.input_tokens == 10
    assert duplicate.output_tokens == 20
    assert duplicate.cached_input_tokens == 5
    assert duplicate.tool_output_tokens == 3
    # ...on a distinct object, so mutating one side never leaks to the other.
    assert duplicate is not source
    duplicate.input_tokens = 999
    assert source.input_tokens == 10  # original must be unaffected
    assert duplicate.input_tokens == 999

    # reset() zeroes every counter on the object it is called on.
    source.reset()
    for field_name in (
        "input_tokens",
        "output_tokens",
        "cached_input_tokens",
        "tool_output_tokens",
    ):
        assert getattr(source, field_name) == 0

    # The earlier copy keeps its own values (including the 999 mutation above).
    assert duplicate.input_tokens == 999
    assert duplicate.output_tokens == 20
    assert duplicate.cached_input_tokens == 5
    assert duplicate.tool_output_tokens == 3