Compare commits
5 commits (marlin_gpt...support_gl)

| Author | SHA1 | Date |
|---|---|---|
| | 98d535eb4f | |
| | a46e279909 | |
| | 064cac7bb7 | |
| | e19bce40a1 | |
| | 505805b645 | |
@@ -280,6 +280,7 @@ steps:
  # split the test to avoid interference
  - pytest -v -s v1/core
  - pytest -v -s v1/executor
  - pytest -v -s v1/offloading
  - pytest -v -s v1/sample
  - pytest -v -s v1/logits_processors
  - pytest -v -s v1/worker
@@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
from typing import Optional

from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger

"""
To run this example, run the following commands simultaneously with
@@ -22,37 +25,67 @@ send a request to the instance with DP rank 1.
"""


def _do_background_logging(engine, interval, stop_event):
    try:
        while not stop_event.is_set():
            asyncio.run(engine.do_log_stats())
            stop_event.wait(interval)
    except Exception as e:
        print(f"vLLM background logging shutdown: {e}")
        pass


async def main():
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        tensor_parallel_size=1,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
        enable_log_requests=True,
        disable_custom_all_reduce=True,
    )

    engine_client = AsyncLLMEngine.from_engine_args(engine_args)
    def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
        return LoggingStatLogger(config, rank)

    engine_client = AsyncLLMEngine.from_engine_args(
        engine_args,
        # Example: Using both regular loggers and aggregated logger
        stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
    )
    stop_logging_event = threading.Event()
    logging_thread = threading.Thread(
        target=_do_background_logging,
        args=(engine_client, 5, stop_logging_event),
        daemon=True,
    )
    logging_thread.start()
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
    )
    num_prompts = 10
    for i in range(num_prompts):
        prompt = "Who won the 2004 World Series?"
        final_output: Optional[RequestOutput] = None
        async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=f"abcdef-{i}",
            data_parallel_rank=1,
        ):
            final_output = output
        if final_output:
            print(final_output.outputs[0].text)

    prompt = "Who won the 2004 World Series?"
    final_output: Optional[RequestOutput] = None
    async for output in engine_client.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id="abcdef",
        data_parallel_rank=1,
    ):
        final_output = output
    if final_output:
        print(final_output.outputs[0].text)
    stop_logging_event.set()
    logging_thread.join()


if __name__ == "__main__":
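For context on the example above: a per-engine factory (a callable taking the config and an engine rank) and an aggregated logger class can be mixed freely in `stat_loggers`. The sketch below is illustrative only; `PrefixedAggregatedLogger` is a hypothetical name, not part of vLLM, and the sketch assumes the constructor and `log()` behavior shown in the diffs in this PR.

```python
# A minimal sketch, assuming the stat_loggers API shown in the diff above.
# PrefixedAggregatedLogger is a hypothetical subclass used for illustration.
from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import AggregatedStatLogger


class PrefixedAggregatedLogger(AggregatedStatLogger):
    """Aggregated logger that only customizes how the summary is emitted."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 engine_indexes: Optional[list[int]] = None):
        super().__init__(vllm_config, engine_indexes)

    def log(self):
        # Reuse the built-in aggregation, then add a custom marker line.
        super().log()
        print(f"[aggregated over {len(self.engine_idxs)} engines]")


# Passed as a class, so vLLM constructs it once with all engine indexes:
# AsyncLLMEngine.from_engine_args(engine_args,
#                                 stat_loggers=[PrefixedAggregatedLogger])
```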
@@ -28,11 +28,9 @@ def monkeypatch_module():
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
@pytest.fixture(scope="module")
def server(monkeypatch_module, zephyr_lora_files): #noqa: F811
    monkeypatch_module.setenv('VLLM_USE_V1', '1')

    args = [
        # use half precision for speed and memory savings in CI environment
@@ -57,13 +55,6 @@ def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811
        yield remote_server


@pytest.fixture
def is_v1_server(server):
    import os
    assert os.environ['VLLM_USE_V1'] in ['0', '1']
    return os.environ['VLLM_USE_V1'] == '1'


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
@@ -481,10 +472,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
async def test_structured_outputs_choice_chat(
        client: openai.AsyncOpenAI, sample_structured_outputs_choices,
        is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Structured outputs is only supported in v1 engine")
    client: openai.AsyncOpenAI,
    sample_structured_outputs_choices,
):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -522,12 +512,10 @@ async def test_structured_outputs_choice_chat(


@pytest.mark.asyncio
async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI,
                                            sample_json_schema,
                                            is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Structured outputs is only supported in v1 engine")

async def test_structured_outputs_json_chat(
    client: openai.AsyncOpenAI,
    sample_json_schema,
):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -569,10 +557,10 @@ async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_structured_outputs_regex_chat(client: openai.AsyncOpenAI,
                                             sample_regex, is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Structured outputs is only supported in v1 engine")
async def test_structured_outputs_regex_chat(
    client: openai.AsyncOpenAI,
    sample_regex,
):

    messages = [{
        "role": "system",
@@ -660,10 +648,10 @@ async def test_structured_outputs_choice_chat_logprobs(


@pytest.mark.asyncio
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
                              is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Tool use is only supported in v1 engine")
async def test_named_tool_use(
    client: openai.AsyncOpenAI,
    sample_json_schema,
):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -821,11 +809,7 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):


@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI,
                                           is_v1_server: bool):
    if not is_v1_server:
        pytest.skip(
            "JSON schema response format is only supported in v1 engine")
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
    prompt = 'what is 1+1? The format is "result": 2'
    # Check that this prompt cannot lead to a valid JSON without json_schema
    for _ in range(2):

@ -1,830 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# imports for structured outputs tests
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import regex as re
|
||||
import requests
|
||||
# downloading lora to test lora requests
|
||||
from openai import BadRequestError
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
# technically these adapters use a different base model,
|
||||
# but we're not testing generation quality here
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(zephyr_lora_files):
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--enforce-eager",
|
||||
# lora config
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"zephyr-lora={zephyr_lora_files}",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=["", "--disable-frontend-multiprocessing"])
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
|
||||
original_value = os.environ.get('VLLM_USE_V1')
|
||||
os.environ['VLLM_USE_V1'] = '0'
|
||||
try:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
finally:
|
||||
# Restore original env value
|
||||
if original_value is None:
|
||||
os.environ.pop('VLLM_USE_V1', None)
|
||||
else:
|
||||
os.environ['VLLM_USE_V1'] = original_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_v1_server(server):
|
||||
import os
|
||||
|
||||
# For completion tests, we assume v0 since there's no explicit v1 setup
|
||||
return os.environ.get('VLLM_USE_V1', '0') == '1'
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
|
||||
# test using token IDs
|
||||
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
|
||||
# Added tokens should be rejected by the base model
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 32000, 32001, 32002],
|
||||
echo=True,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=0,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=5,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=21,
|
||||
)
|
||||
...
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
stream = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=30,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in stream:
|
||||
...
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
|
||||
(MODEL_NAME, 0),
|
||||
(MODEL_NAME, 1),
|
||||
(MODEL_NAME, None)])
|
||||
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
prompt_logprobs: Optional[int]):
|
||||
params: dict = {
|
||||
"prompt": ["A robot may not injure another robot", "My name is"],
|
||||
"model": model_name,
|
||||
}
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs is not None and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(**params)
|
||||
else:
|
||||
completion = await client.completions.create(**params)
|
||||
if prompt_logprobs is not None:
|
||||
assert completion.choices[0].prompt_logprobs is not None
|
||||
assert len(completion.choices[0].prompt_logprobs) > 0
|
||||
|
||||
assert completion.choices[1].prompt_logprobs is not None
|
||||
assert len(completion.choices[1].prompt_logprobs) > 0
|
||||
|
||||
else:
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
single_completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
# finish reason should only return in last block
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == "length"
|
||||
assert chunk.choices[0].text
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
|
||||
"""Streaming for parallel sampling.
|
||||
The tokens from multiple samples, are flattened into a single stream,
|
||||
with an index to indicate which sample the token belongs to.
|
||||
"""
|
||||
|
||||
prompt = "What is an LLM?"
|
||||
n = 3
|
||||
max_tokens = 5
|
||||
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
stream=True)
|
||||
chunks: list[list[str]] = [[] for i in range(n)]
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
index = chunk.choices[0].index
|
||||
text = chunk.choices[0].text
|
||||
chunks[index].append(text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert finish_reason_count == n
|
||||
for chunk in chunks:
|
||||
assert len(chunk) == max_tokens
|
||||
print("".join(chunk))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": False, "continuous_usage_stats": False}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": False,
|
||||
"continuous_usage_stats":
|
||||
False,
|
||||
})
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is None
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": False, "continuous_usage_stats": True}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": False,
|
||||
"continuous_usage_stats":
|
||||
True,
|
||||
})
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is None
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": True, "continuous_usage_stats": False}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats":
|
||||
False,
|
||||
})
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].finish_reason is None:
|
||||
assert chunk.usage is None
|
||||
else:
|
||||
assert chunk.usage is None
|
||||
final_chunk = await stream.__anext__()
|
||||
assert final_chunk.usage is not None
|
||||
assert final_chunk.usage.prompt_tokens > 0
|
||||
assert final_chunk.usage.completion_tokens > 0
|
||||
assert final_chunk.usage.total_tokens == (
|
||||
final_chunk.usage.prompt_tokens +
|
||||
final_chunk.usage.completion_tokens)
|
||||
assert final_chunk.choices == []
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": True, "continuous_usage_stats": True}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats":
|
||||
True,
|
||||
})
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is not None
|
||||
assert chunk.usage.prompt_tokens > 0
|
||||
assert chunk.usage.completion_tokens > 0
|
||||
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
|
||||
chunk.usage.completion_tokens)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
final_chunk = await stream.__anext__()
|
||||
assert final_chunk.usage is not None
|
||||
assert final_chunk.usage.prompt_tokens > 0
|
||||
assert final_chunk.usage.completion_tokens > 0
|
||||
assert final_chunk.usage.total_tokens == (
|
||||
final_chunk.usage.prompt_tokens +
|
||||
final_chunk.usage.completion_tokens)
|
||||
assert final_chunk.choices == []
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"include_usage": None}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"include_usage": None})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"include_usage": True}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"include_usage": True})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"continuous_usage_stats": None}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"continuous_usage_stats": None})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"continuous_usage_stats": True}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"continuous_usage_stats": True})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test both text and token IDs
|
||||
for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
|
||||
# test simple list
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(batch.choices) == 2
|
||||
assert batch.choices[0].text == batch.choices[1].text
|
||||
|
||||
# test n = 2
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
n=2,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body=dict(
|
||||
# NOTE: this has to be true for n > 1 in vLLM, but
|
||||
# not necessary for official client.
|
||||
use_beam_search=True),
|
||||
)
|
||||
assert len(batch.choices) == 4
|
||||
assert batch.choices[0].text != batch.choices[
|
||||
1].text, "beam search should be different"
|
||||
assert batch.choices[0].text == batch.choices[
|
||||
2].text, "two copies of the same prompt should be the same"
|
||||
assert batch.choices[1].text == batch.choices[
|
||||
3].text, "two copies of the same prompt should be the same"
|
||||
|
||||
# test streaming
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
texts = [""] * 2
|
||||
async for chunk in batch:
|
||||
assert len(chunk.choices) == 1
|
||||
choice = chunk.choices[0]
|
||||
texts[choice.index] += choice.text
|
||||
assert texts[0] == texts[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logits_bias(client: openai.AsyncOpenAI):
|
||||
prompt = "Hello, my name is"
|
||||
max_tokens = 5
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
|
||||
# Test exclusive selection
|
||||
token_id = 1000
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
logit_bias={str(token_id): 100},
|
||||
seed=42,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 5
|
||||
response_tokens = tokenizer(completion.choices[0].text,
|
||||
add_special_tokens=False)["input_ids"]
|
||||
expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
|
||||
add_special_tokens=False)["input_ids"]
|
||||
assert all([
|
||||
response == expected
|
||||
for response, expected in zip(response_tokens, expected_tokens)
|
||||
])
|
||||
|
||||
# Test ban
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
)
|
||||
response_tokens = tokenizer(completion.choices[0].text,
|
||||
add_special_tokens=False)["input_ids"]
|
||||
first_response = completion.choices[0].text
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
logit_bias={str(token): -100
|
||||
for token in response_tokens},
|
||||
)
|
||||
assert first_response != completion.choices[0].text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
|
||||
prompt = "Hello, my name is"
|
||||
max_tokens = 1
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
|
||||
# Test exclusive selection
|
||||
allowed_ids = [21555, 21557, 21558]
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
seed=42,
|
||||
extra_body=dict(allowed_token_ids=allowed_ids),
|
||||
logprobs=1,
|
||||
)
|
||||
response_tokens = completion.choices[0].logprobs.tokens
|
||||
assert len(response_tokens) == 1
|
||||
assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_outputs_json_completion(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_json_schema,
|
||||
is_v1_server: bool,
|
||||
):
|
||||
if not is_v1_server:
|
||||
pytest.skip("structured outputs is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {sample_json_schema}",
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=500,
|
||||
extra_body=dict(structured_outputs=dict(json=sample_json_schema)))
|
||||
|
||||
assert completion.id is not None
|
||||
assert len(completion.choices) == 3
|
||||
for i in range(3):
|
||||
output_json = json.loads(completion.choices[i].text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_outputs_regex_completion(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_regex,
|
||||
is_v1_server: bool,
|
||||
):
|
||||
if not is_v1_server:
|
||||
pytest.skip("structured outputs is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=20,
|
||||
extra_body=dict(structured_outputs=dict(regex=sample_regex)))
|
||||
|
||||
assert completion.id is not None
|
||||
assert len(completion.choices) == 3
|
||||
for i in range(3):
|
||||
assert re.fullmatch(sample_regex,
|
||||
completion.choices[i].text) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_outputs_choice_completion(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_structured_outputs_choices,
|
||||
is_v1_server: bool,
|
||||
):
|
||||
if not is_v1_server:
|
||||
pytest.skip("structured outputs is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="The best language for type-safe systems programming is ",
|
||||
n=2,
|
||||
temperature=1.0,
|
||||
max_tokens=10,
|
||||
extra_body=dict(structured_outputs=dict(
|
||||
choice=sample_structured_outputs_choices)))
|
||||
|
||||
assert completion.id is not None
|
||||
assert len(completion.choices) == 2
|
||||
for i in range(2):
|
||||
assert completion.choices[i].text in sample_structured_outputs_choices
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_outputs_grammar(client: openai.AsyncOpenAI,
|
||||
sample_sql_statements,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("grammar is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=("Generate a sql state that select col_1 from "
|
||||
"table_1 where it is equals to 1"),
|
||||
temperature=1.0,
|
||||
max_tokens=500,
|
||||
extra_body=dict(
|
||||
structured_outputs=dict(grammar=sample_sql_statements), ))
|
||||
|
||||
content = completion.choices[0].text
|
||||
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(sample_sql_statements)
|
||||
parser.parse(content)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
|
||||
|
||||
assert content.strip() == ground_truth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
@pytest.mark.parametrize("logprobs_arg", [1, 0])
|
||||
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str, logprobs_arg: int):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# test using text and token IDs
|
||||
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=True,
|
||||
logprobs=logprobs_arg)
|
||||
|
||||
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
|
||||
list) else prompt
|
||||
assert re.search(r"^" + prompt_text, completion.choices[0].text)
|
||||
logprobs = completion.choices[0].logprobs
|
||||
assert logprobs is not None
|
||||
assert len(logprobs.text_offset) > 5
|
||||
assert (len(logprobs.token_logprobs) > 5
|
||||
and logprobs.token_logprobs[0] is None)
|
||||
assert (len(logprobs.top_logprobs) > 5
|
||||
and logprobs.top_logprobs[0] is None)
|
||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
||||
assert max(logprobs_arg,
|
||||
1) <= len(top_logprobs) <= logprobs_arg + 1
|
||||
assert len(logprobs.tokens) > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_outputs_type_error(client: openai.AsyncOpenAI,
|
||||
sample_json_schema, sample_regex,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("structured outputs is only supported in v1 engine")
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="Give an example JSON that fits this schema: 42",
|
||||
extra_body=dict(structured_outputs=dict(json=42)))
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="Give an example string that fits this regex",
|
||||
extra_body=dict(structured_outputs=dict(
|
||||
regex=sample_regex,
|
||||
json=sample_json_schema,
|
||||
)))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,stream,echo",
|
||||
[
|
||||
(MODEL_NAME, False, False),
|
||||
(MODEL_NAME, False, True),
|
||||
(MODEL_NAME, True, False),
|
||||
(MODEL_NAME, True, True) # should not raise BadRequestError error
|
||||
],
|
||||
)
|
||||
async def test_echo_stream_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str, stream: bool,
|
||||
echo: bool):
|
||||
saying: str = "Hello, my name is"
|
||||
result = await client.completions.create(model=model_name,
|
||||
prompt=saying,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
echo=echo,
|
||||
stream=stream)
|
||||
|
||||
stop_reason = "length"
|
||||
|
||||
if not stream:
|
||||
completion = result
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == stop_reason
|
||||
|
||||
if echo:
|
||||
assert choice.text is not None and saying in choice.text
|
||||
else:
|
||||
assert choice.text is not None and saying not in choice.text
|
||||
|
||||
else:
|
||||
chunks: list[str] = []
|
||||
final_finish_reason = None
|
||||
async for chunk in result:
|
||||
if chunk.choices and chunk.choices[0].text:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices and chunk.choices[0].finish_reason:
|
||||
final_finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
assert final_finish_reason == stop_reason
|
||||
content = "".join(chunks)
|
||||
if echo:
|
||||
assert content is not None and saying in content
|
||||
else:
|
||||
assert content is not None and saying not in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello, my name is",
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
completion = await client.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["choices"] == invocation_output["choices"]
|
||||
@@ -14,6 +14,9 @@ from transformers import AutoConfig

from ...utils import RemoteOpenAIServer

pytest.skip("Skipping prompt_embeds test until V1 supports it.",
            allow_module_level=True)

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

@@ -53,12 +53,13 @@ def monkeypatch_module():
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
@pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, monkeypatch_module,
                                  zephyr_lora_files):

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
    assert use_v1
    monkeypatch_module.setenv('VLLM_USE_V1', '1')

    # Define the json format LoRA module configurations
    lora_module_1 = {
@@ -22,7 +22,7 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PREV_MINOR_VERSION = version._prev_minor_version()


@pytest.fixture(scope="module", params=[True, False])
@pytest.fixture(scope="module", params=[True])
def use_v1(request):
    # Module-scoped variant of run_with_both_engines
    #
@@ -10,8 +10,30 @@ import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args  # noqa: F401
from .test_completion import MODEL_NAME

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files):
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # lora config
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
    ]


@pytest.fixture(scope="module")
@@ -15,14 +15,6 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def server():
    args = [
@@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
        self.log = MagicMock()


class MockAggregatedStatLogger(AggregatedStatLogger):

    def __init__(self,
                 vllm_config: VllmConfig,
                 engine_indexes: Optional[list[int]] = None):
        super().__init__(vllm_config, engine_indexes)
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
@@ -415,6 +424,35 @@ async def test_customize_loggers(monkeypatch):
    stat_loggers[0][0].log.assert_called_once()


@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
    """Test that we can customize the aggregated loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """

    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.per_engine_logger_dict
        assert len(stat_loggers) == 1
        assert len(
            stat_loggers[0]) == 2  # LoggingStatLogger + MockLoggingStatLogger
        aggregated_loggers = engine.logger_manager.aggregated_loggers
        assert len(aggregated_loggers) == 1
        aggregated_loggers[0].log.assert_called_once()
        stat_loggers[0][0].log.assert_called_once()


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m, ExitStack() as after:
tests/v1/offloading/test_worker.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
                                              OffloadingWorker, TransferResult,
                                              TransferSpec)


class LoadStoreSpec1(LoadStoreSpec):

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        self.finished = False
        self.submit_success = submit_success
        self.async_success = async_success
        self.exception = exception

    @staticmethod
    def medium() -> str:
        return "1"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"


class LoadStoreSpec2(LoadStoreSpec):

    @staticmethod
    def medium() -> str:
        return "2"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"


class OffloadingHandler1To2(OffloadingHandler):

    def __init__(self):
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src, dst = spec
        assert isinstance(src, LoadStoreSpec1)
        assert isinstance(dst, LoadStoreSpec2)

        if src.exception:
            raise Exception("An expected exception. Don't worry!")
        if not src.submit_success:
            return False

        self.transfers[job_id] = src
        return True

    def get_finished(self) -> list[TransferResult]:
        finished = []
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished


class OffloadingHandler2To1(OffloadingHandler):

    def __init__(self):
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src, dst = spec
        assert isinstance(src, LoadStoreSpec2)
        assert isinstance(dst, LoadStoreSpec1)

        self.transfers[job_id] = dst
        return True

    def get_finished(self) -> list[TransferResult]:
        finished = []
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished


def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.
    One handler performs 1->2 transfers, and the other handles 2->1.
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (exception)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, (src1, dst1))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, (src2, dst2))

    # 3rd transfer 1->2 (failure)
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, (src3, dst3))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    worker.transfer_async(4, (src4, dst4))
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    worker.transfer_async(5, (src5, dst5))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    worker.transfer_async(6, (src6, dst6))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
    worker.transfer_async(7, (src7, dst7))

    # 6th and 7th transfers started
    assert 6 in handler1to2.transfers
    assert 7 in handler2to1.transfers

    # verify result of 3rd and 4th transfers
    assert (sorted(worker.get_finished()) == [(3, False), (4, True)])

    # complete 6th and 7th transfers
    src6.finished = True
    dst7.finished = True
    assert (sorted(worker.get_finished()) == [(6, True), (7, True)])
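For context on the test above, a real caller of `OffloadingWorker` would submit transfers and then poll `get_finished()` until every job is reported, rather than flipping `finished` flags directly. The sketch below is a hypothetical polling helper built only on the `transfer_async`/`get_finished` interface exercised by the test; the helper name and timeout are assumptions, not vLLM APIs.

```python
# A minimal sketch of driving OffloadingWorker; wait_for_transfers is a
# hypothetical helper, not part of vLLM.
import time

from vllm.v1.offloading.worker.worker import OffloadingWorker, TransferResult


def wait_for_transfers(worker: OffloadingWorker,
                       job_ids: set[int],
                       timeout_s: float = 10.0) -> list[TransferResult]:
    """Poll the worker until all submitted job_ids have finished."""
    results: list[TransferResult] = []
    deadline = time.monotonic() + timeout_s
    while job_ids and time.monotonic() < deadline:
        # get_finished() returns (job_id, success) pairs, as in the test above.
        for job_id, success in worker.get_finished():
            results.append((job_id, success))
            job_ids.discard(job_id)
        time.sleep(0.01)  # avoid a busy loop
    return results
```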
@@ -7,7 +7,6 @@ import pytest
import vllm.envs as envs
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

MODEL = "meta-llama/Llama-3.2-1B-Instruct"

@@ -96,20 +95,3 @@ def test_v1_attn_backend(monkeypatch):
        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
        assert envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")


def test_reject_using_constructor_directly(monkeypatch):
    with monkeypatch.context() as m:
        if os.getenv("VLLM_USE_V1", None):
            m.delenv("VLLM_USE_V1")

        # Sets VLLM_USE_V1=1.
        vllm_config = AsyncEngineArgs(model=MODEL).create_engine_config()

        # This uses the V0 constructor directly.
        with pytest.raises(ValueError):
            AsyncLLMEngine(vllm_config,
                           AsyncLLMEngine._get_executor_cls(vllm_config),
                           log_stats=True)

        m.delenv("VLLM_USE_V1")

File diff suppressed because it is too large.
@@ -11,7 +11,6 @@ import uvicorn
from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
                                        H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
@@ -154,7 +153,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """

    @app.exception_handler(RuntimeError)
    @app.exception_handler(AsyncEngineDeadError)
    @app.exception_handler(EngineDeadError)
    @app.exception_handler(EngineGenerateError)
    async def runtime_exception_handler(request: Request, __):

@@ -38,7 +38,6 @@ from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (load_chat_template,
                                         resolve_hf_chat_template,
@@ -201,50 +200,34 @@ async def build_async_engine_client_from_engine_args(
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # V1 AsyncLLM.
    if envs.VLLM_USE_V1:
        if disable_frontend_multiprocessing:
            logger.warning(
                "V1 is enabled, but got --disable-frontend-multiprocessing. "
                "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
    assert envs.VLLM_USE_V1

        from vllm.v1.engine.async_llm import AsyncLLM
        async_llm: Optional[AsyncLLM] = None
        client_count = client_config.pop(
            "client_count") if client_config else 1
        client_index = client_config.pop(
            "client_index") if client_config else 0
        try:
            async_llm = AsyncLLM.from_vllm_config(
                vllm_config=vllm_config,
                usage_context=usage_context,
                enable_log_requests=engine_args.enable_log_requests,
                disable_log_stats=engine_args.disable_log_stats,
                client_addresses=client_config,
                client_count=client_count,
                client_index=client_index)
    if disable_frontend_multiprocessing:
        logger.warning(
            "V1 is enabled, but got --disable-frontend-multiprocessing. "
            "To disable frontend multiprocessing, set VLLM_USE_V1=0.")

            # Don't keep the dummy data in memory
            await async_llm.reset_mm_cache()
    from vllm.v1.engine.async_llm import AsyncLLM
    async_llm: Optional[AsyncLLM] = None
    client_count = client_config.pop("client_count") if client_config else 1
    client_index = client_config.pop("client_index") if client_config else 0
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index)

            yield async_llm
        finally:
            if async_llm:
                async_llm.shutdown()
        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

    # V0 AsyncLLM.
    else:

        engine_client: Optional[EngineClient] = None
        try:
            engine_client = AsyncLLMEngine.from_vllm_config(
                vllm_config=vllm_config,
                usage_context=usage_context,
                enable_log_requests=engine_args.enable_log_requests,
                disable_log_stats=engine_args.disable_log_stats)
            yield engine_client
        finally:
            if engine_client and hasattr(engine_client, "shutdown"):
                engine_client.shutdown()
        yield async_llm
    finally:
        if async_llm:
            async_llm.shutdown()


async def validate_json_request(raw_request: Request):

@@ -76,7 +76,6 @@ class OAIAttention(nn.Module):

        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=torch.bfloat16,
                        requires_grad=False))

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
@@ -145,8 +144,7 @@ class MLPBlock(torch.nn.Module):
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
                                      config.num_local_experts)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,

@ -18,7 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
|
||||
type["AggregatedStatLoggerBase"]]
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
@ -48,6 +50,16 @@ class StatLoggerBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class AggregatedStatLoggerBase(StatLoggerBase):
|
||||
"""Abstract base class for loggers that
|
||||
aggregates statistics across multiple engines."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]]):
|
||||
...
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
@ -61,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
self.engine_is_idle = False
|
||||
|
||||
def _reset(self, now):
|
||||
self.last_log_time = now
|
||||
@ -100,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
def get_log_stats(self):
|
||||
now = time.monotonic()
|
||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||
generation_throughput = self._get_throughput(
|
||||
self.num_generation_tokens, now)
|
||||
|
||||
self._reset(now)
|
||||
|
||||
scheduler_stats = self.last_scheduler_stats
|
||||
|
||||
log_fn = logger.info
|
||||
if not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
self.last_generation_throughput = generation_throughput
|
||||
self.last_prompt_throughput = prompt_throughput
|
||||
self.engine_is_idle = not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput))
|
||||
|
||||
def log(self):
|
||||
self.get_log_stats()
|
||||
log_fn = logger.info
|
||||
if self.engine_is_idle:
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"Engine %03d: "
|
||||
@ -128,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
self.engine_index,
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
scheduler_stats.num_running_reqs,
|
||||
scheduler_stats.num_waiting_reqs,
|
||||
scheduler_stats.kv_cache_usage * 100,
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput,
|
||||
self.last_scheduler_stats.num_running_reqs,
|
||||
self.last_scheduler_stats.num_waiting_reqs,
|
||||
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
@ -145,7 +158,61 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):

    def __init__(self,
                 vllm_config: VllmConfig,
                 engine_idxs: Optional[list[int]] = None):
        if engine_idxs is None:
            engine_idxs = [0]
        self.engine_idxs = engine_idxs
        LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)

    def record(
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        engine_idx: int = 0,
    ):
        if engine_idx not in self.engine_idxs:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return
        LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
                                 engine_idx)

    def log(self):
        self.get_log_stats()
        log_fn = logger.info
        if self.engine_is_idle:
            # Avoid log noise on an idle production system
            log_fn = logger.debug
        # Format and print output.
        log_fn(
            "%s Engines Aggregated: "
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Waiting: %d reqs, "
            "GPU KV cache usage: %.1f%%, "
            "Prefix cache hit rate: %.1f%%",
            len(self.engine_idxs),
            self.last_prompt_throughput,
            self.last_generation_throughput,
            self.last_scheduler_stats.num_running_reqs,
            self.last_scheduler_stats.num_waiting_reqs,
            self.last_scheduler_stats.kv_cache_usage * 100,
            self.prefix_caching_metrics.hit_rate * 100,
        )
        self.spec_decoding_logging.log(log_fn=log_fn)

    def log_engine_initialized(self):
        if self.vllm_config.cache_config.num_gpu_blocks:
            logger.info(
                "%d Engines: vllm cache_config_info with initialization "
                "after num_gpu_blocks is: %d", len(self.engine_idxs),
                self.vllm_config.cache_config.num_gpu_blocks)


class PrometheusStatLogger(AggregatedStatLoggerBase):
    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram
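For reference, a rough sketch of the line the AggregatedStatLogger.log format string above produces (the numbers are illustrative, not measured):

# e.g. with two engines:
# "2 Engines Aggregated: Avg prompt throughput: 120.0 tokens/s, "
# "Avg generation throughput: 45.3 tokens/s, Running: 4 reqs, Waiting: 0 reqs, "
# "GPU KV cache usage: 12.5%, Prefix cache hit rate: 80.0%"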
@ -674,23 +741,32 @@ class StatLoggerManager:

        # engine_idx: StatLogger
        self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
        prometheus_factory = PrometheusStatLogger
        self.aggregated_loggers: list[AggregatedStatLoggerBase] = []

        aggregated_loggers_factories = set()
        for engine_idx in self.engine_idxs:
            loggers: list[StatLoggerBase] = []
            for logger_factory in factories:
                # If we get a custom prometheus logger, use that
                # instead. This is typically used for the ray case.
                if (isinstance(logger_factory, type)
                        and issubclass(logger_factory, PrometheusStatLogger)):
                    prometheus_factory = logger_factory
                    continue
                loggers.append(logger_factory(vllm_config,
                                              engine_idx))  # type: ignore
                # If we get a custom prometheus logger or aggregated logger,
                # we initialize it separately with all engine idxs.
                # A custom prometheus logger is typically used for the Ray case.
                if (isinstance(logger_factory, type) and issubclass(
                        logger_factory, AggregatedStatLoggerBase)):
                    aggregated_loggers_factories.add(logger_factory)
                else:
                    loggers.append(logger_factory(vllm_config,
                                                  engine_idx))  # type: ignore
            self.per_engine_logger_dict[engine_idx] = loggers

        # For Prometheus, need to share the metrics between EngineCores.
        # If no custom aggregated logger is provided,
        # we use PrometheusStatLogger by default.
        if not aggregated_loggers_factories:
            aggregated_loggers_factories.add(PrometheusStatLogger)
        # For a custom aggregated logger (or the default Prometheus logger),
        # the metrics need to be shared between EngineCores.
        # Each EngineCore's metrics are expressed as a unique label.
        self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
        for aggregated_loggers_factory in aggregated_loggers_factories:
            self.aggregated_loggers.append(
                aggregated_loggers_factory(vllm_config, engine_idxs))

    def record(
        self,
@ -704,18 +780,19 @@ class StatLoggerManager:
        per_engine_loggers = self.per_engine_logger_dict[engine_idx]
        for logger in per_engine_loggers:
            logger.record(scheduler_stats, iteration_stats, engine_idx)

        self.prometheus_logger.record(scheduler_stats, iteration_stats,
                                      engine_idx)
        for logger in self.aggregated_loggers:
            logger.record(scheduler_stats, iteration_stats, engine_idx)

    def log(self):
        for per_engine_loggers in self.per_engine_logger_dict.values():
            for logger in per_engine_loggers:
                logger.log()
        for logger in self.aggregated_loggers:
            logger.log()

    def log_engine_initialized(self):
        self.prometheus_logger.log_engine_initialized()

        for agg_logger in self.aggregated_loggers:
            agg_logger.log_engine_initialized()
        for per_engine_loggers in self.per_engine_logger_dict.values():
            for logger in per_engine_loggers:
                logger.log_engine_initialized()
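Since StatLoggerManager now treats any class factory deriving from AggregatedStatLoggerBase as a cross-engine logger, a custom aggregate can be plugged in without touching the per-engine path. A minimal sketch (the subclass name and extra log line are hypothetical, not part of this change):

# Hypothetical example: a custom aggregated logger. Because it derives from
# AggregatedStatLoggerBase (via AggregatedStatLogger), StatLoggerManager
# constructs it once with all engine indices instead of once per engine.
class VerboseAggregatedLogger(AggregatedStatLogger):

    def log(self):
        # Reuse the aggregated formatting, then add one extra summary line.
        super().log()
        logger.info("Aggregating stats over engine indices: %s",
                    self.engine_idxs)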
vllm/v1/offloading/abstract.py (new file, 165 lines)
@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
OffloadingManager class for managing KV data offloading in vLLM v1

This class runs in the scheduler, tracks which blocks are offloaded
and their addresses.

The class provides the following primitives:
    lookup() - finds the length of the maximal series of blocks,
        starting from the first one, that are all offloaded.
    prepare_load() - prepares the given blocks to be read.
        The given blocks will be protected from eviction.
        This function returns a LoadStoreSpec which encapsulates
        the information required for performing the load.
    touch() - marks the given blocks as recently used. Can be used
        to track block LRU order. This function is separated from the
        prepare_load function to allow setting block recency even
        for blocks which do not need reading from the cache, such as
        blocks that are cached by the GPU prefix cache.
    complete_load() - marks blocks which were previously prepared to be
        loaded as done loading. This is to re-allow their eviction.
    prepare_store() - prepares the given blocks to be written.
        Returns a StoreSpec encapsulating offloading information,
        as well as a list of blocks that were evicted as a result.
    complete_store() - marks a previous store as completed.
        Following this call, the given blocks will become loadable.
"""

from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

from vllm.v1.core.kv_cache_utils import BlockHash


class LoadStoreSpec(ABC):
    """
    Abstract metadata that encapsulates information allowing a worker
    to load, and optionally also to store, blocks of KV data.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Returns a string representation of the medium type
        this store/load targets.
        """
        pass


@dataclass
class PrepareStoreOutput:
    block_hashes_to_store: list[BlockHash]
    store_spec: LoadStoreSpec
    block_hashes_evicted: list[BlockHash]


@dataclass
class OffloadingEvent:
    block_hashes: list[BlockHash]
    block_size: int
    medium: str
    # True if blocks are removed, False if stored
    removed: bool


class OffloadingManager(ABC):

    @abstractmethod
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to look up.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker to locate and load
            the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: Iterable[BlockHash]):
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: Iterable[BlockHash]):
        """
        Marks blocks that were previously prepared for loading as done loading.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(
            self,
            block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and a list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self,
                       block_hashes: Iterable[BlockHash],
                       success: bool = True):
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not stored will be removed.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.

        Yields:
            New OffloadingEvents collected since the last call.
        """
        return ()
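To make the intended call order concrete, here is a small scheduler-side sketch (not part of the diff; `manager` stands for any concrete OffloadingManager implementation and the helper name is made up):

def schedule_offloaded_load(
        manager: OffloadingManager,
        block_hashes: list[BlockHash]) -> Optional[LoadStoreSpec]:
    # How many leading blocks are already offloaded?
    num_hits = manager.lookup(block_hashes)
    if num_hits == 0:
        return None
    hit_hashes = block_hashes[:num_hits]
    # Refresh recency even if some blocks end up served by the GPU prefix cache.
    manager.touch(hit_hashes)
    # Protect the blocks from eviction and obtain the load metadata.
    # The caller must invoke manager.complete_load(hit_hashes) once the
    # worker has finished reading them.
    return manager.prepare_load(hit_hashes)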
vllm/v1/offloading/mediums.py (new file, 39 lines)
@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC

import numpy as np

from vllm.v1.offloading.abstract import LoadStoreSpec


class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Spec for loading/storing KV blocks from given block numbers.
    """

    def __init__(self, block_ids: list[int]):
        self.block_ids = np.array(block_ids, dtype=np.int64)

    def __repr__(self) -> str:
        return repr(self.block_ids)


class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing a KV block to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "GPU"


class CPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing a KV block to CPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "CPU"
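As a quick orientation (illustrative block IDs, not from the diff), two of these specs pair up to describe one transfer:

# Example only: a CPU -> GPU load described by a (src, dst) pair of specs.
src = CPULoadStoreSpec([7, 8, 9])
dst = GPULoadStoreSpec([0, 1, 2])
# (src.medium(), dst.medium()) == ("CPU", "GPU") is the transfer type the
# OffloadingWorker below uses to pick a handler.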
vllm/v1/offloading/worker/worker.py (new file, 142 lines)
@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

from vllm.logger import init_logger
from vllm.v1.offloading.abstract import LoadStoreSpec

# a single transfer spec (src_blocks_spec, dst_blocks_spec)
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]
# transfers are forwarded to workers by (src_medium, dst_medium)
TransferType = tuple[str, str]
# transfer result (job_id, success)
TransferResult = tuple[int, bool]

logger = init_logger(__name__)


class OffloadingHandler(ABC):
    """
    OffloadingHandler class for managing asynchronous KV data transfers.

    This class runs in the worker.
    It kicks off async KV data transfer requests, and allows
    collecting back completion statuses.

    The class provides the following primitives:
        transfer_async() - kicks off a new transfer job
        get_finished() - returns a list of newly finished job IDs.
    """

    @abstractmethod
    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Initiates an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if the transfer was submitted successfully.
        """
        pass

    @abstractmethod
    def get_finished(self) -> list[TransferResult]:
        """
        Get transfers finished since the last call.

        Returns:
            A list of (job_id, success) tuples for finished transfers.
        """
        pass


class OffloadingWorker:
    """
    OffloadingWorker class for managing asynchronous KV data transfers
    using multiple OffloadingHandlers.

    This class runs in the worker.
    It kicks off async KV data transfer requests, by delegating
    to one of its registered OffloadingHandlers, based on the transfer type.

    The class provides the following primitives:
        register_handler() - registers a new handler to handle
            a specific transfer type
        transfer_async() - kicks off a new transfer job
            using one of the registered handlers.
        get_finished() - returns a list of newly finished job IDs
            from all handlers.
    """

    def __init__(self):
        self.handlers: set[OffloadingHandler] = set()
        self.transfer_type_to_handler: dict[TransferType,
                                            OffloadingHandler] = {}

    def register_handler(self, src_cls: type[LoadStoreSpec],
                         dst_cls: type[LoadStoreSpec],
                         handler: OffloadingHandler) -> None:
        """
        Registers a new handler.

        Args:
            src_cls: the source type of transfers handled by this handler.
            dst_cls: the destination type of transfers handled by this handler.
            handler: the handler that will handle transfers.
        """
        transfer_type = (src_cls.medium(), dst_cls.medium())
        assert transfer_type not in self.transfer_type_to_handler
        self.handlers.add(handler)
        self.transfer_type_to_handler[transfer_type] = handler

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Initiates an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if the transfer was submitted successfully.
        """
        src, dst = spec
        transfer_type = (src.medium(), dst.medium())
        handler = self.transfer_type_to_handler.get(transfer_type)
        assert handler is not None

        try:
            success = handler.transfer_async(job_id, spec)
        except Exception as e:
            logger.warning("Exception in %r transfer %d: %r",
                           transfer_type,
                           job_id,
                           e,
                           exc_info=True)
            return False

        if not success:
            logger.warning("Failed to submit %r transfer %d", transfer_type,
                           job_id)
        else:
            logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id,
                         spec)

        return success

    def get_finished(self) -> list[TransferResult]:
        """
        Get transfers finished since the last call.

        Returns:
            A list of (job_id, success) tuples for finished transfers.
        """
        finished = []
        for handler in self.handlers:
            finished.extend(handler.get_finished())
        return finished
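To show how the pieces fit together, a minimal sketch of a handler plus registration (the NoopHandler is hypothetical and completes transfers synchronously; a real handler would issue asynchronous copies):

from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec


class NoopHandler(OffloadingHandler):
    """Hypothetical handler that marks every transfer as instantly finished."""

    def __init__(self):
        self._finished: list[TransferResult] = []

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        self._finished.append((job_id, True))
        return True

    def get_finished(self) -> list[TransferResult]:
        finished, self._finished = self._finished, []
        return finished


worker = OffloadingWorker()
worker.register_handler(CPULoadStoreSpec, GPULoadStoreSpec, NoopHandler())
worker.transfer_async(1, (CPULoadStoreSpec([7, 8]), GPULoadStoreSpec([0, 1])))
assert worker.get_finished() == [(1, True)]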