Support Anthropic API /v1/messages Endpoint (#22627)
Signed-off-by: liuli <ll407707@alibaba-inc.com> Co-authored-by: liuli <ll407707@alibaba-inc.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@ -48,3 +48,4 @@ pybase64 # fast base64 implementation
|
||||
cbor2 # Required for cross-language serialization of hashable objects
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
|
||||
0
tests/entrypoints/anthropic/__init__.py
Normal file
0
tests/entrypoints/anthropic/__init__.py
Normal file
141
tests/entrypoints/anthropic/test_messages.py
Normal file
141
tests/entrypoints/anthropic/test_messages.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import anthropic
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteAnthropicServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(): # noqa: F811
|
||||
args = [
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--enforce-eager",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
"--served-model-name",
|
||||
"claude-3-7-sonnet-latest",
|
||||
]
|
||||
|
||||
with RemoteAnthropicServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_messages(client: anthropic.AsyncAnthropic):
|
||||
resp = await client.messages.create(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "how are you!"}],
|
||||
)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
assert resp.role == "assistant"
|
||||
|
||||
print(f"Anthropic response: {resp.model_dump_json()}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_message(client: anthropic.AsyncAnthropic):
|
||||
resp = await client.messages.create(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
max_tokens=1024,
|
||||
system="you are a helpful assistant",
|
||||
messages=[{"role": "user", "content": "how are you!"}],
|
||||
)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
assert resp.role == "assistant"
|
||||
|
||||
print(f"Anthropic response: {resp.model_dump_json()}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
|
||||
resp = await client.messages.create(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": "how are you!"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in resp:
|
||||
print(chunk.model_dump_json())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
|
||||
resp = await client.messages.create(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{"role": "user", "content": "What's the weather like in New York today?"}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Useful for querying the weather in a specified city.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City or region, for example: "
|
||||
"New York, London, Tokyo, etc.",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
assert resp.stop_reason == "tool_use"
|
||||
assert resp.role == "assistant"
|
||||
|
||||
print(f"Anthropic response: {resp.model_dump_json()}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
|
||||
resp = await client.messages.create(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
max_tokens=1024,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in New York today?",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Useful for querying the weather "
|
||||
"in a specified city.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City or region, for example: "
|
||||
"New York, London, Tokyo, etc.",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in resp:
|
||||
print(chunk.model_dump_json())
|
||||
126
tests/utils.py
126
tests/utils.py
@ -23,6 +23,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import anthropic
|
||||
import cloudpickle
|
||||
import httpx
|
||||
import openai
|
||||
@ -294,6 +295,131 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
|
||||
self.proc.kill()
|
||||
|
||||
|
||||
class RemoteAnthropicServer:
|
||||
DUMMY_API_KEY = "token-abc123" # vLLM's Anthropic server does not need API key
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
vllm_serve_args: list[str],
|
||||
*,
|
||||
env_dict: dict[str, str] | None = None,
|
||||
seed: int | None = 0,
|
||||
auto_port: bool = True,
|
||||
max_wait_seconds: float | None = None,
|
||||
) -> None:
|
||||
if auto_port:
|
||||
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
|
||||
raise ValueError(
|
||||
"You have manually specified the port when `auto_port=True`."
|
||||
)
|
||||
|
||||
# Don't mutate the input args
|
||||
vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
|
||||
if seed is not None:
|
||||
if "--seed" in vllm_serve_args:
|
||||
raise ValueError(
|
||||
f"You have manually specified the seed when `seed={seed}`."
|
||||
)
|
||||
|
||||
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
|
||||
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
parser = ServeSubcommand().subparser_init(subparsers)
|
||||
args = parser.parse_args(["--model", model, *vllm_serve_args])
|
||||
self.host = str(args.host or "localhost")
|
||||
self.port = int(args.port)
|
||||
|
||||
self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
|
||||
|
||||
# download the model before starting the server to avoid timeout
|
||||
is_local = os.path.isdir(model)
|
||||
if not is_local:
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
model_config = engine_args.create_model_config()
|
||||
load_config = engine_args.create_load_config()
|
||||
|
||||
model_loader = get_model_loader(load_config)
|
||||
model_loader.download_model(model_config)
|
||||
|
||||
env = os.environ.copy()
|
||||
# the current process might initialize cuda,
|
||||
# to be safe, we should use spawn method
|
||||
env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
if env_dict is not None:
|
||||
env.update(env_dict)
|
||||
self.proc = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"vllm.entrypoints.anthropic.api_server",
|
||||
model,
|
||||
*vllm_serve_args,
|
||||
],
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
max_wait_seconds = max_wait_seconds or 240
|
||||
self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.proc.terminate()
|
||||
try:
|
||||
self.proc.wait(8)
|
||||
except subprocess.TimeoutExpired:
|
||||
# force kill if needed
|
||||
self.proc.kill()
|
||||
|
||||
def _wait_for_server(self, *, url: str, timeout: float):
|
||||
# run health check
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get(url).status_code == 200:
|
||||
break
|
||||
except Exception:
|
||||
# this exception can only be raised by requests.get,
|
||||
# which means the server is not ready yet.
|
||||
# the stack trace is not useful, so we suppress it
|
||||
# by using `raise from None`.
|
||||
result = self.proc.poll()
|
||||
if result is not None and result != 0:
|
||||
raise RuntimeError("Server exited unexpectedly.") from None
|
||||
|
||||
time.sleep(0.5)
|
||||
if time.time() - start > timeout:
|
||||
raise RuntimeError("Server failed to start in time.") from None
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
def url_for(self, *parts: str) -> str:
|
||||
return self.url_root + "/" + "/".join(parts)
|
||||
|
||||
def get_client(self, **kwargs):
|
||||
if "timeout" not in kwargs:
|
||||
kwargs["timeout"] = 600
|
||||
return anthropic.Anthropic(
|
||||
base_url=self.url_for(),
|
||||
api_key=self.DUMMY_API_KEY,
|
||||
max_retries=0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_async_client(self, **kwargs):
|
||||
if "timeout" not in kwargs:
|
||||
kwargs["timeout"] = 600
|
||||
return anthropic.AsyncAnthropic(
|
||||
base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _test_completion(
|
||||
client: openai.OpenAI,
|
||||
model: str,
|
||||
|
||||
0
vllm/entrypoints/anthropic/__init__.py
Normal file
0
vllm/entrypoints/anthropic/__init__.py
Normal file
300
vllm/entrypoints/anthropic/api_server.py
Normal file
300
vllm/entrypoints/anthropic/api_server.py
Normal file
@ -0,0 +1,300 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from:
|
||||
# https://github.com/vllm/vllm/entrypoints/openai/api_server.py
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from http import HTTPStatus
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.datastructures import State
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicErrorResponse,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client,
|
||||
create_server_socket,
|
||||
lifespan,
|
||||
load_log_config,
|
||||
validate_api_server_args,
|
||||
validate_json_request,
|
||||
)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_models import (
|
||||
BaseModelPath,
|
||||
OpenAIServingModels,
|
||||
)
|
||||
|
||||
#
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
load_aware_call,
|
||||
process_chat_template,
|
||||
process_lora_modules,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser, set_ulimit
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger("vllm.entrypoints.anthropic.api_server")
|
||||
|
||||
_running_tasks: set[asyncio.Task] = set()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def messages(request: Request) -> AnthropicServingMessages:
|
||||
return request.app.state.anthropic_serving_messages
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/ping", response_class=Response)
|
||||
@router.post("/ping", response_class=Response)
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
|
||||
handler = messages(raw_request)
|
||||
if handler is None:
|
||||
return messages(raw_request).create_error_response(
|
||||
message="The model does not support Messages API"
|
||||
)
|
||||
|
||||
generator = await handler.create_messages(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
elif isinstance(generator, AnthropicMessagesResponse):
|
||||
logger.debug(
|
||||
"Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
|
||||
)
|
||||
return JSONResponse(content=generator.model_dump(exclude_none=True))
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
async def init_app_state(
|
||||
engine_client: EngineClient,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
vllm_config = engine_client.vllm_config
|
||||
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
default_mm_loras = (
|
||||
vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None
|
||||
else {}
|
||||
)
|
||||
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
|
||||
|
||||
resolved_chat_template = await process_chat_template(
|
||||
args.chat_template, engine_client, model_config
|
||||
)
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
)
|
||||
await state.openai_serving_models.init_static_loras()
|
||||
state.anthropic_serving_messages = AnthropicServingMessages(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
|
||||
|
||||
def setup_server(args):
|
||||
"""Validate API server args, set up signal handler, create socket
|
||||
ready to serve."""
|
||||
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
validate_api_server_args(args)
|
||||
|
||||
# workaround to make sure that we bind the port before the engine is set up.
|
||||
# This avoids race conditions with ray.
|
||||
# see https://github.com/vllm-project/vllm/issues/8204
|
||||
sock_addr = (args.host or "", args.port)
|
||||
sock = create_server_socket(sock_addr)
|
||||
|
||||
# workaround to avoid footguns where uvicorn drops requests with too
|
||||
# many concurrent requests active
|
||||
set_ulimit()
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
addr, port = sock_addr
|
||||
is_ssl = args.ssl_keyfile and args.ssl_certfile
|
||||
host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
|
||||
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
|
||||
|
||||
return listen_address, sock
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
"""Run a single-worker API server."""
|
||||
listen_address, sock = setup_server(args)
|
||||
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
app.root_path = args.root_path
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.allowed_origins,
|
||||
allow_credentials=args.allow_credentials,
|
||||
allow_methods=args.allowed_methods,
|
||||
allow_headers=args.allowed_headers,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
async def run_server_worker(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Run a single API server worker."""
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
server_index = client_config.get("client_index", 0) if client_config else 0
|
||||
|
||||
# Load logging config for uvicorn if specified
|
||||
log_config = load_log_config(args.log_config_file)
|
||||
if log_config is not None:
|
||||
uvicorn_kwargs["log_config"] = log_config
|
||||
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
client_config=client_config,
|
||||
) as engine_client:
|
||||
app = build_app(args)
|
||||
|
||||
await init_app_state(engine_client, app.state, args)
|
||||
|
||||
logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=sock,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.uvicorn_log_level,
|
||||
# NOTE: When the 'disable_uvicorn_access_log' value is True,
|
||||
# no access log will be output.
|
||||
access_log=not args.disable_uvicorn_access_log,
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
# NB: Await server shutdown only after the backend context is exited
|
||||
try:
|
||||
await shutdown_task
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE(simon):
|
||||
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
|
||||
# entrypoints.
|
||||
cli_env_setup()
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM Anthropic-Compatible RESTful API server."
|
||||
)
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args()
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
uvloop.run(run_server(args))
|
||||
162
vllm/entrypoints/anthropic/protocol.py
Normal file
162
vllm/entrypoints/anthropic/protocol.py
Normal file
@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pydantic models for Anthropic API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class AnthropicError(BaseModel):
|
||||
"""Error structure for Anthropic API"""
|
||||
|
||||
type: str
|
||||
message: str
|
||||
|
||||
|
||||
class AnthropicErrorResponse(BaseModel):
|
||||
"""Error response structure for Anthropic API"""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
error: AnthropicError
|
||||
|
||||
|
||||
class AnthropicUsage(BaseModel):
|
||||
"""Token usage information"""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: int | None = None
|
||||
cache_read_input_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicContentBlock(BaseModel):
|
||||
"""Content block in message"""
|
||||
|
||||
type: Literal["text", "image", "tool_use", "tool_result"]
|
||||
text: str | None = None
|
||||
# For image content
|
||||
source: dict[str, Any] | None = None
|
||||
# For tool use/result
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
input: dict[str, Any] | None = None
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
is_error: bool | None = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
"""Message structure"""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: str | list[AnthropicContentBlock]
|
||||
|
||||
|
||||
class AnthropicTool(BaseModel):
|
||||
"""Tool definition"""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
@field_validator("input_schema")
|
||||
@classmethod
|
||||
def validate_input_schema(cls, v):
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("input_schema must be a dictionary")
|
||||
if "type" not in v:
|
||||
v["type"] = "object" # Default to object type
|
||||
return v
|
||||
|
||||
|
||||
class AnthropicToolChoice(BaseModel):
|
||||
"""Tool Choice definition"""
|
||||
|
||||
type: Literal["auto", "any", "tool"]
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class AnthropicMessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request"""
|
||||
|
||||
model: str
|
||||
messages: list[AnthropicMessage]
|
||||
max_tokens: int
|
||||
metadata: dict[str, Any] | None = None
|
||||
stop_sequences: list[str] | None = None
|
||||
stream: bool | None = False
|
||||
system: str | list[AnthropicContentBlock] | None = None
|
||||
temperature: float | None = None
|
||||
tool_choice: AnthropicToolChoice | None = None
|
||||
tools: list[AnthropicTool] | None = None
|
||||
top_k: int | None = None
|
||||
top_p: float | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, v):
|
||||
if not v:
|
||||
raise ValueError("Model is required")
|
||||
return v
|
||||
|
||||
@field_validator("max_tokens")
|
||||
@classmethod
|
||||
def validate_max_tokens(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError("max_tokens must be positive")
|
||||
return v
|
||||
|
||||
|
||||
class AnthropicDelta(BaseModel):
|
||||
"""Delta for streaming responses"""
|
||||
|
||||
type: Literal["text_delta", "input_json_delta"] | None = None
|
||||
text: str | None = None
|
||||
partial_json: str | None = None
|
||||
|
||||
# Message delta
|
||||
stop_reason: (
|
||||
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
|
||||
) = None
|
||||
stop_sequence: str | None = None
|
||||
|
||||
|
||||
class AnthropicStreamEvent(BaseModel):
|
||||
"""Streaming event"""
|
||||
|
||||
type: Literal[
|
||||
"message_start",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"ping",
|
||||
"error",
|
||||
]
|
||||
message: Optional["AnthropicMessagesResponse"] = None
|
||||
delta: AnthropicDelta | None = None
|
||||
content_block: AnthropicContentBlock | None = None
|
||||
index: int | None = None
|
||||
error: AnthropicError | None = None
|
||||
usage: AnthropicUsage | None = None
|
||||
|
||||
|
||||
class AnthropicMessagesResponse(BaseModel):
|
||||
"""Anthropic Messages API response"""
|
||||
|
||||
id: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[AnthropicContentBlock]
|
||||
model: str
|
||||
stop_reason: (
|
||||
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
|
||||
) = None
|
||||
stop_sequence: str | None = None
|
||||
usage: AnthropicUsage | None = None
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if not self.id:
|
||||
self.id = f"msg_{int(time.time() * 1000)}"
|
||||
458
vllm/entrypoints/anthropic/serving_messages.py
Normal file
458
vllm/entrypoints/anthropic/serving_messages.py
Normal file
@ -0,0 +1,458 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from
|
||||
# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py
|
||||
|
||||
"""Anthropic Messages API serving handler"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicContentBlock,
|
||||
AnthropicDelta,
|
||||
AnthropicError,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicStreamEvent,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionToolsParam,
|
||||
ErrorResponse,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wrap_data_with_event(data: str, event: str):
|
||||
return f"event: {event}\ndata: {data}\n\n"
|
||||
|
||||
|
||||
class AnthropicServingMessages(OpenAIServingChat):
|
||||
"""Handler for Anthropic Messages API requests"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
response_role: str,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
reasoning_parser: str = "",
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: str | None = None,
|
||||
enable_prompt_tokens_details: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
response_role=response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_auto_tools=enable_auto_tools,
|
||||
tool_parser=tool_parser,
|
||||
enable_prompt_tokens_details=enable_prompt_tokens_details,
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
self.stop_reason_map = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
}
|
||||
|
||||
def _convert_anthropic_to_openai_request(
|
||||
self, anthropic_request: AnthropicMessagesRequest
|
||||
) -> ChatCompletionRequest:
|
||||
"""Convert Anthropic message format to OpenAI format"""
|
||||
openai_messages = []
|
||||
|
||||
# Add system message if provided
|
||||
if anthropic_request.system:
|
||||
if isinstance(anthropic_request.system, str):
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": anthropic_request.system}
|
||||
)
|
||||
else:
|
||||
system_prompt = ""
|
||||
for block in anthropic_request.system:
|
||||
if block.type == "text" and block.text:
|
||||
system_prompt += block.text
|
||||
openai_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for msg in anthropic_request.messages:
|
||||
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
|
||||
if isinstance(msg.content, str):
|
||||
openai_msg["content"] = msg.content
|
||||
else:
|
||||
# Handle complex content blocks
|
||||
content_parts: list[dict[str, Any]] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for block in msg.content:
|
||||
if block.type == "text" and block.text:
|
||||
content_parts.append({"type": "text", "text": block.text})
|
||||
elif block.type == "image" and block.source:
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": block.source.get("data", "")},
|
||||
}
|
||||
)
|
||||
elif block.type == "tool_use":
|
||||
# Convert tool use to function call format
|
||||
tool_call = {
|
||||
"id": block.id or f"call_{int(time.time())}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name or "",
|
||||
"arguments": json.dumps(block.input or {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
elif block.type == "tool_result":
|
||||
if msg.role == "user":
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": block.id or "",
|
||||
"content": str(block.content)
|
||||
if block.content
|
||||
else "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Assistant tool result becomes regular text
|
||||
tool_result_text = (
|
||||
str(block.content) if block.content else ""
|
||||
)
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Tool result: {tool_result_text}",
|
||||
}
|
||||
)
|
||||
|
||||
# Add tool calls to the message if any
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
|
||||
# Add content parts if any
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls:
|
||||
continue
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
max_tokens=anthropic_request.max_tokens,
|
||||
max_completion_tokens=anthropic_request.max_tokens,
|
||||
stop=anthropic_request.stop_sequences,
|
||||
temperature=anthropic_request.temperature,
|
||||
top_p=anthropic_request.top_p,
|
||||
top_k=anthropic_request.top_k,
|
||||
)
|
||||
|
||||
if anthropic_request.stream:
|
||||
req.stream = anthropic_request.stream
|
||||
req.stream_options = StreamOptions.validate({"include_usage": True})
|
||||
|
||||
if anthropic_request.tool_choice is None:
|
||||
req.tool_choice = None
|
||||
elif anthropic_request.tool_choice.type == "auto":
|
||||
req.tool_choice = "auto"
|
||||
elif anthropic_request.tool_choice.type == "any":
|
||||
req.tool_choice = "required"
|
||||
elif anthropic_request.tool_choice.type == "tool":
|
||||
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": anthropic_request.tool_choice.name},
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
if anthropic_request.tools is None:
|
||||
return req
|
||||
for tool in anthropic_request.tools:
|
||||
tools.append(
|
||||
ChatCompletionToolsParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
if req.tool_choice is None:
|
||||
req.tool_choice = "auto"
|
||||
req.tools = tools
|
||||
return req
|
||||
|
||||
async def create_messages(
|
||||
self,
|
||||
request: AnthropicMessagesRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AsyncGenerator[str, None] | AnthropicMessagesResponse | ErrorResponse:
|
||||
"""
|
||||
Messages API similar to Anthropic's API.
|
||||
|
||||
See https://docs.anthropic.com/en/api/messages
|
||||
for the API specification. This API mimics the Anthropic messages API.
|
||||
"""
|
||||
logger.debug("Received messages request %s", request.model_dump_json())
|
||||
chat_req = self._convert_anthropic_to_openai_request(request)
|
||||
logger.debug("Convert to OpenAI request %s", request.model_dump_json())
|
||||
generator = await self.create_chat_completion(chat_req, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return generator
|
||||
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return self.messages_full_converter(generator)
|
||||
|
||||
return self.message_stream_converter(generator)
|
||||
|
||||
def messages_full_converter(
|
||||
self,
|
||||
generator: ChatCompletionResponse,
|
||||
) -> AnthropicMessagesResponse:
|
||||
result = AnthropicMessagesResponse(
|
||||
id=generator.id,
|
||||
content=[],
|
||||
model=generator.model,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=generator.usage.prompt_tokens,
|
||||
output_tokens=generator.usage.completion_tokens,
|
||||
),
|
||||
)
|
||||
if generator.choices[0].finish_reason == "stop":
|
||||
result.stop_reason = "end_turn"
|
||||
elif generator.choices[0].finish_reason == "length":
|
||||
result.stop_reason = "max_tokens"
|
||||
elif generator.choices[0].finish_reason == "tool_calls":
|
||||
result.stop_reason = "tool_use"
|
||||
|
||||
content: list[AnthropicContentBlock] = [
|
||||
AnthropicContentBlock(
|
||||
type="text",
|
||||
text=generator.choices[0].message.content
|
||||
if generator.choices[0].message.content
|
||||
else "",
|
||||
)
|
||||
]
|
||||
|
||||
for tool_call in generator.choices[0].message.tool_calls:
|
||||
anthropic_tool_call = AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
content += [anthropic_tool_call]
|
||||
|
||||
result.content = content
|
||||
|
||||
return result
|
||||
|
||||
async def message_stream_converter(
|
||||
self,
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
first_item = True
|
||||
finish_reason = None
|
||||
content_block_index = 0
|
||||
content_block_started = False
|
||||
|
||||
async for item in generator:
|
||||
if item.startswith("data:"):
|
||||
data_str = item[5:].strip().rstrip("\n")
|
||||
if data_str == "[DONE]":
|
||||
stop_message = AnthropicStreamEvent(
|
||||
type="message_stop",
|
||||
)
|
||||
data = stop_message.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True
|
||||
)
|
||||
yield wrap_data_with_event(data, "message_stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
else:
|
||||
origin_chunk = ChatCompletionStreamResponse.model_validate_json(
|
||||
data_str
|
||||
)
|
||||
|
||||
if first_item:
|
||||
chunk = AnthropicStreamEvent(
|
||||
type="message_start",
|
||||
message=AnthropicMessagesResponse(
|
||||
id=origin_chunk.id,
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
),
|
||||
)
|
||||
first_item = False
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "message_start")
|
||||
continue
|
||||
|
||||
# last chunk including usage info
|
||||
if len(origin_chunk.choices) == 0:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_stop")
|
||||
stop_reason = self.stop_reason_map.get(
|
||||
finish_reason or "stop"
|
||||
)
|
||||
chunk = AnthropicStreamEvent(
|
||||
type="message_delta",
|
||||
delta=AnthropicDelta(stop_reason=stop_reason),
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
output_tokens=origin_chunk.usage.completion_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "message_delta")
|
||||
continue
|
||||
|
||||
if origin_chunk.choices[0].finish_reason is not None:
|
||||
finish_reason = origin_chunk.choices[0].finish_reason
|
||||
continue
|
||||
|
||||
# content
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if not content_block_started:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="text", text=""
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
|
||||
if origin_chunk.choices[0].delta.content == "":
|
||||
continue
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="text_delta",
|
||||
text=origin_chunk.choices[0].delta.content,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
|
||||
# tool calls
|
||||
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.id is not None:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(
|
||||
exclude_unset=True
|
||||
)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_stop"
|
||||
)
|
||||
content_block_started = False
|
||||
content_block_index += 1
|
||||
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name
|
||||
if tool_call.function
|
||||
else None,
|
||||
input={},
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
|
||||
else:
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments
|
||||
if tool_call.function
|
||||
else None,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
else:
|
||||
error_response = AnthropicStreamEvent(
|
||||
type="error",
|
||||
error=AnthropicError(
|
||||
type="internal_error",
|
||||
message="Invalid data format received",
|
||||
),
|
||||
)
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in message stream converter.")
|
||||
error_response = AnthropicStreamEvent(
|
||||
type="error",
|
||||
error=AnthropicError(type="internal_error", message=str(e)),
|
||||
)
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
@ -41,11 +41,6 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
load_chat_template,
|
||||
resolve_hf_chat_template,
|
||||
resolve_mistral_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
@ -90,7 +85,6 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (
|
||||
BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
@ -107,11 +101,12 @@ from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
load_aware_call,
|
||||
log_non_default_args,
|
||||
process_chat_template,
|
||||
process_lora_modules,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (
|
||||
Device,
|
||||
@ -1655,32 +1650,9 @@ async def init_app_state(
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
if resolved_chat_template is not None:
|
||||
# Get the tokenizer to check official template
|
||||
tokenizer = await engine_client.get_tokenizer()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# The warning is logged in resolve_mistral_chat_template.
|
||||
resolved_chat_template = resolve_mistral_chat_template(
|
||||
chat_template=resolved_chat_template
|
||||
)
|
||||
else:
|
||||
hf_chat_template = resolve_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=None,
|
||||
tools=None,
|
||||
model_config=vllm_config.model_config,
|
||||
)
|
||||
|
||||
if hf_chat_template != resolved_chat_template:
|
||||
logger.warning(
|
||||
"Using supplied chat template: %s\n"
|
||||
"It is different from official chat template '%s'. "
|
||||
"This discrepancy may lead to performance degradation.",
|
||||
resolved_chat_template,
|
||||
args.model,
|
||||
)
|
||||
resolved_chat_template = await process_chat_template(
|
||||
args.chat_template, engine_client, vllm_config.model_config
|
||||
)
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: ToolServer | None = DemoToolServer()
|
||||
@ -1699,19 +1671,12 @@ async def init_app_state(
|
||||
else {}
|
||||
)
|
||||
|
||||
lora_modules = args.lora_modules
|
||||
if default_mm_loras:
|
||||
default_mm_lora_paths = [
|
||||
LoRAModulePath(
|
||||
name=modality,
|
||||
path=lora_path,
|
||||
)
|
||||
for modality, lora_path in default_mm_loras.items()
|
||||
]
|
||||
if args.lora_modules is None:
|
||||
lora_modules = default_mm_lora_paths
|
||||
else:
|
||||
lora_modules += default_mm_lora_paths
|
||||
default_mm_loras = (
|
||||
vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None
|
||||
else {}
|
||||
)
|
||||
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
|
||||
@ -6,21 +6,31 @@ import dataclasses
|
||||
import functools
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
load_chat_template,
|
||||
resolve_hf_chat_template,
|
||||
resolve_mistral_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -254,3 +264,56 @@ def should_include_usage(
|
||||
else:
|
||||
include_usage, include_continuous_usage = enable_force_include_usage, False
|
||||
return include_usage, include_continuous_usage
|
||||
|
||||
|
||||
def process_lora_modules(
|
||||
args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None
|
||||
) -> list[LoRAModulePath]:
|
||||
lora_modules = args_lora_modules
|
||||
if default_mm_loras:
|
||||
default_mm_lora_paths = [
|
||||
LoRAModulePath(
|
||||
name=modality,
|
||||
path=lora_path,
|
||||
)
|
||||
for modality, lora_path in default_mm_loras.items()
|
||||
]
|
||||
if args_lora_modules is None:
|
||||
lora_modules = default_mm_lora_paths
|
||||
else:
|
||||
lora_modules += default_mm_lora_paths
|
||||
return lora_modules
|
||||
|
||||
|
||||
async def process_chat_template(
|
||||
args_chat_template: Path | str | None,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
) -> str | None:
|
||||
resolved_chat_template = load_chat_template(args_chat_template)
|
||||
if resolved_chat_template is not None:
|
||||
# Get the tokenizer to check official template
|
||||
tokenizer = await engine_client.get_tokenizer()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# The warning is logged in resolve_mistral_chat_template.
|
||||
resolved_chat_template = resolve_mistral_chat_template(
|
||||
chat_template=resolved_chat_template
|
||||
)
|
||||
else:
|
||||
hf_chat_template = resolve_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=None,
|
||||
tools=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if hf_chat_template != resolved_chat_template:
|
||||
logger.warning(
|
||||
"Using supplied chat template: %s\n"
|
||||
"It is different from official chat template '%s'. "
|
||||
"This discrepancy may lead to performance degradation.",
|
||||
resolved_chat_template,
|
||||
model_config.model,
|
||||
)
|
||||
return resolved_chat_template
|
||||
|
||||
Reference in New Issue
Block a user