diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py new file mode 100644 index 0000000000..6292306e7c --- /dev/null +++ b/vllm/entrypoints/context.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from abc import ABC, abstractmethod + +from openai_harmony import Message, Role, StreamState + +from vllm.entrypoints.harmony_utils import ( + get_encoding, get_streamable_parser_for_assistant, render_for_completion) +from vllm.entrypoints.tool import Tool +from vllm.outputs import RequestOutput + +logger = logging.getLogger(__name__) + + +class ConversationContext(ABC): + + @abstractmethod + def append_output(self, output) -> None: + pass + + @abstractmethod + async def call_tool(self) -> list[Message]: + pass + + @abstractmethod + def need_builtin_tool_call(self) -> bool: + pass + + @abstractmethod + def render_for_completion(self) -> list[int]: + pass + + +class SimpleContext(ConversationContext): + + def __init__(self): + self.last_output = None + + def append_output(self, output) -> None: + self.last_output = output + + def need_builtin_tool_call(self) -> bool: + return False + + async def call_tool(self) -> list[Message]: + raise NotImplementedError("Should not be called.") + + def render_for_completion(self) -> list[int]: + raise NotImplementedError("Should not be called.") + + +class HarmonyContext(ConversationContext): + + def __init__( + self, + messages: list, + tool_sessions: dict[str, Tool], + ): + self._messages = messages + self.tool_sessions = tool_sessions + + self.parser = get_streamable_parser_for_assistant() + self.num_init_messages = len(messages) + # TODO(woosuk): Implement the following fields. + self.num_prompt_tokens = 0 + self.num_cached_tokens = 0 + self.num_output_tokens = 0 + self.num_reasoning_tokens = 0 + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + output_token_ids = output.outputs[0].token_ids + for token_id in output_token_ids: + self.parser.process(token_id) + output_msgs = self.parser.messages + else: + # Tool output. + output_msgs = output + self._messages.extend(output_msgs) + + @property + def messages(self) -> list: + return self._messages + + def need_builtin_tool_call(self) -> bool: + last_msg = self.messages[-1] + recipient = last_msg.recipient + return recipient is not None and (recipient.startswith("browser.") + or recipient.startswith("python")) + + async def call_tool(self) -> list[Message]: + if not self.messages: + return [] + last_msg = self.messages[-1] + recipient = last_msg.recipient + if recipient is not None: + if recipient.startswith("browser."): + return await self.call_search_tool( + self.tool_sessions["browser"], last_msg) + elif recipient.startswith("python"): + return await self.call_python_tool( + self.tool_sessions["python"], last_msg) + raise ValueError("No tool call found") + + def render_for_completion(self) -> list[int]: + return render_for_completion(self.messages) + + async def call_search_tool( + self, + tool_session: Tool, + last_msg: Message, + ) -> list[Message]: + return await tool_session.get_result(self) + + async def call_python_tool( + self, + tool_session: Tool, + last_msg: Message, + ) -> list[Message]: + return await tool_session.get_result(self) + + +class StreamingHarmonyContext(HarmonyContext): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_output = None + + self.parser = get_streamable_parser_for_assistant() + self.encoding = get_encoding() + self.last_tok = None + + @property + def messages(self) -> list: + return self.parser.messages + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + tok = output.outputs[0].token_ids[0] + self.parser.process(tok) + self.last_tok = tok + else: + # Handle the case of tool output in direct message format + assert len(output) == 1, "Tool output should be a single message" + msg = output[0] + # Sometimes the recipient is not set for tool messages, + # so we set it to "assistant" + if msg.author.role == Role.TOOL and msg.recipient is None: + msg.recipient = "assistant" + toks = self.encoding.render(msg) + for tok in toks: + self.parser.process(tok) + self.last_tok = toks[-1] + + def is_expecting_start(self) -> bool: + return self.parser.state == StreamState.EXPECT_START + + def is_assistant_action_turn(self) -> bool: + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( + ) + + def render_for_completion(self) -> list[int]: + # now this list of tokens as next turn's starting tokens + # `<|start|>assistant``, + # we need to process them in parser. + rendered_tokens = super().render_for_completion() + + last_n = -1 + to_process = [] + while rendered_tokens[last_n] != self.last_tok: + to_process.append(rendered_tokens[last_n]) + last_n -= 1 + for tok in reversed(to_process): + self.parser.process(tok) + + return rendered_tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py new file mode 100644 index 0000000000..801c82b4fa --- /dev/null +++ b/vllm/entrypoints/harmony_utils.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import datetime +from typing import Literal, Optional + +from openai.types.responses.tool import Tool +from openai_harmony import (Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, + Role, StreamableParser, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date is None: + # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message(instructions: Optional[str] = None, + tools: Optional[list[Tool]] = None) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. + pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_chat_input(chat_msg) -> Message: + role = chat_msg["role"] + content = chat_msg["content"] + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. + contents = [TextContent(text=c["text"]) for c in content] + msg = Message.from_role_and_contents(role, contents) + return msg + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py new file mode 100644 index 0000000000..01ee77414f --- /dev/null +++ b/vllm/entrypoints/tool.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +if TYPE_CHECKING: + # Avoid circular import. + from vllm.entrypoints.context import ConversationContext + +logger = init_logger(__name__) + + +class Tool(ABC): + + @abstractmethod + async def get_result(self, context: "ConversationContext") -> Any: + pass + + +class HarmonyBrowserTool(Tool): + + def __init__(self): + self.enabled = True + exa_api_key = os.getenv("EXA_API_KEY") + if not exa_api_key: + self.enabled = False + logger.warning_once("EXA_API_KEY is not set, browsing is disabled") + return + + try: + from gpt_oss.tools.simple_browser import SimpleBrowserTool + from gpt_oss.tools.simple_browser.backend import ExaBackend + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, browsing is disabled") + return + + browser_backend = ExaBackend(source="web", api_key=exa_api_key) + self.browser_tool = SimpleBrowserTool(backend=browser_backend) + logger.info_once("Browser tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.browser_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.browser_tool.tool_config + + +class HarmonyPythonTool(Tool): + + def __init__(self): + self.enabled = True + + try: + from gpt_oss.tools.python_docker.docker_tool import PythonTool + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, code interpreter is disabled") + return + + self.python_tool = PythonTool() + logger.info_once("Code interpreter tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.python_tool.tool_config