Compare commits
14 Commits
v0.7.0
...
tpu_v1_opt
| Author | SHA1 | Date | |
|---|---|---|---|
| 70b4e46e70 | |||
| 5fb9dbe6f6 | |||
| 996b92ccb4 | |||
| 2b0526fa15 | |||
| 7be649256f | |||
| 627efde813 | |||
| c2867d5bc1 | |||
| 39c4a4cdb5 | |||
| 1ccf100c6a | |||
| 248c5b632d | |||
| 950f349492 | |||
| 61bb55f3d5 | |||
| 0bddb6b9a5 | |||
| c715fb19e5 |
@ -2,11 +2,8 @@ import os
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB
|
||||
# Note that we have 400 MiB quota, please use it wisely.
|
||||
# See https://github.com/pypi/support/issues/3792 .
|
||||
# Please also sync the value with the one in Dockerfile.
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
|
||||
@ -89,4 +89,4 @@ repos:
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
verbose: true
|
||||
6
CMakeLists.txt
Executable file → Normal file
6
CMakeLists.txt
Executable file → Normal file
@ -446,9 +446,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_C_LIBS cuda)
|
||||
endif()
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
@ -457,7 +454,6 @@ define_gpu_extension_target(
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
LIBRARIES ${VLLM_C_LIBS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@ -580,7 +576,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c
|
||||
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
# sync the default value with .buildkite/check-wheel-size.py
|
||||
ARG VLLM_MAX_SIZE_MB=300
|
||||
# Default max size of the wheel is 250MB
|
||||
ARG VLLM_MAX_SIZE_MB=250
|
||||
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||
ARG RUN_WHEEL_CHECK=true
|
||||
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20250124"
|
||||
ARG NIGHTLY_DATE="20250122"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -51,8 +51,7 @@ async def async_request_tgi(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
@ -124,8 +123,7 @@ async def async_request_trt_llm(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
@ -189,8 +187,7 @@ async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
|
||||
payload = {
|
||||
@ -238,8 +235,7 @@ async def async_request_openai_completions(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
@ -337,8 +333,7 @@ async def async_request_openai_chat_completions(
|
||||
"chat/completions"
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
|
||||
@ -200,7 +200,7 @@ def sample_sonnet_requests(
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_vision_arena_requests(
|
||||
def sample_mmmu_pro_vision_requests(
|
||||
dataset,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -212,7 +212,13 @@ def sample_vision_arena_requests(
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
prompt = data["turns"][0][0]['content']
|
||||
# MMMU-Pro vision direct prompt
|
||||
# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
|
||||
prompt = (
|
||||
"Answer with the option letter from the given choices directly. "
|
||||
"The last line of your response should be of the following "
|
||||
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
|
||||
"options.")
|
||||
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
if fixed_output_len is None:
|
||||
@ -224,10 +230,10 @@ def sample_vision_arena_requests(
|
||||
output_len = fixed_output_len
|
||||
|
||||
assert isinstance(
|
||||
data["images"][0],
|
||||
data["image"],
|
||||
Image), ("Input image format must be `PIL.Image.Image`, "
|
||||
f"given {type(data['image'])}.")
|
||||
image: Image = data["images"][0]
|
||||
image: Image = data["image"]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
@ -246,7 +252,7 @@ def sample_vision_arena_requests(
|
||||
|
||||
def sample_hf_requests(
|
||||
dataset_path: str,
|
||||
dataset_subset: Optional[str],
|
||||
dataset_subset: str,
|
||||
dataset_split: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -254,17 +260,19 @@ def sample_hf_requests(
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||
|
||||
# Special case for vision_arena dataset
|
||||
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
|
||||
and dataset_subset is None:
|
||||
assert dataset_split == "train"
|
||||
# Special case for MMMU-Pro vision dataset
|
||||
if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
|
||||
assert dataset_split == "test"
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
dataset = dataset.shuffle(seed=random_seed)
|
||||
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
|
||||
fixed_output_len)
|
||||
assert "image" in dataset.features, (
|
||||
"MMMU/MMMU_Pro vision dataset must have 'image' column.")
|
||||
filter_func = lambda x: isinstance(x["image"], Image)
|
||||
dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||
return sample_mmmu_pro_vision_requests(dataset, num_requests,
|
||||
tokenizer, fixed_output_len)
|
||||
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
|
||||
@ -33,9 +33,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
|
||||
token_cnts_t* tokens_cnts =
|
||||
(token_cnts_t*)(shared_mem + num_experts +
|
||||
1); // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
# vLLM Blog
|
||||
|
||||
vLLM blog posts are published [here](https://blog.vllm.ai/).
|
||||
@ -59,7 +59,6 @@ To build and install vLLM from source, run:
|
||||
```console
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
pip install -r requirements-hpu.txt
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
@ -69,7 +68,6 @@ Currently, the latest features and performance optimizations are developed in Ga
|
||||
git clone https://github.com/HabanaAI/vllm-fork.git
|
||||
cd vllm-fork
|
||||
git checkout habana_main
|
||||
pip install -r requirements-hpu.txt
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
|
||||
@ -184,7 +184,6 @@ api/model/index
|
||||
:caption: Community
|
||||
:maxdepth: 1
|
||||
|
||||
community/blog
|
||||
community/meetups
|
||||
community/sponsors
|
||||
```
|
||||
|
||||
@ -50,11 +50,6 @@ In addition, we have the following custom APIs:
|
||||
- Applicable to all [pooling models](../models/pooling_models.md).
|
||||
- [Score API](#score-api) (`/score`)
|
||||
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
|
||||
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
|
||||
- Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
|
||||
- Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
|
||||
- Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response.
|
||||
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
|
||||
|
||||
(chat-template)=
|
||||
|
||||
@ -478,90 +473,3 @@ The following extra parameters are supported:
|
||||
:start-after: begin-score-extra-params
|
||||
:end-before: end-score-extra-params
|
||||
```
|
||||
|
||||
(rerank-api)=
|
||||
|
||||
### Re-rank API
|
||||
|
||||
Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
|
||||
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
|
||||
a scale of 0 to 1.
|
||||
|
||||
You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
|
||||
|
||||
The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
|
||||
`score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
|
||||
endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
|
||||
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
|
||||
popular open-source tools.
|
||||
|
||||
Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py>
|
||||
|
||||
#### Example Request
|
||||
|
||||
Note that the `top_n` request parameter is optional and will default to the length of the `documents` field.
|
||||
Result documents will be sorted by relevance, and the `index` property can be used to determine original order.
|
||||
|
||||
Request:
|
||||
|
||||
```bash
|
||||
curl -X 'POST' \
|
||||
'http://127.0.0.1:8000/v1/rerank' \
|
||||
-H 'accept: application/json' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "BAAI/bge-reranker-base",
|
||||
"query": "What is the capital of France?",
|
||||
"documents": [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
"Horses and cows are both animals"
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```bash
|
||||
{
|
||||
"id": "rerank-fae51b2b664d4ed38f5969b612edff77",
|
||||
"model": "BAAI/bge-reranker-base",
|
||||
"usage": {
|
||||
"total_tokens": 56
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"index": 1,
|
||||
"document": {
|
||||
"text": "The capital of France is Paris."
|
||||
},
|
||||
"relevance_score": 0.99853515625
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"document": {
|
||||
"text": "The capital of Brazil is Brasilia."
|
||||
},
|
||||
"relevance_score": 0.0005860328674316406
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters](#pooling-params) are supported.
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-rerank-pooling-params
|
||||
:end-before: end-rerank-pooling-params
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-rerank-extra-params
|
||||
:end-before: end-rerank-extra-params
|
||||
```
|
||||
|
||||
@ -8,10 +8,10 @@ prompts = [
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
@ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
@ -13,7 +13,7 @@ The OpenAI batch file format consists of a series of json objects on new lines.
|
||||
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
|
||||
|
||||
```{note}
|
||||
We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` endpoints (completions coming soon).
|
||||
We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).
|
||||
```
|
||||
|
||||
## Pre-requisites
|
||||
@ -203,34 +203,3 @@ $ cat results.jsonl
|
||||
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
|
||||
...
|
||||
```
|
||||
|
||||
## Example 5: Using score endpoint
|
||||
|
||||
### Additional prerequisites
|
||||
|
||||
* Ensure you are using `vllm >= 0.7.0`.
|
||||
|
||||
### Step 1: Create your batch file
|
||||
|
||||
Add score requests to your batch file. The following is an example:
|
||||
|
||||
```
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||
```
|
||||
|
||||
You can mix chat completion, embedding, and score requests in the batch file, as long as the model you are using supports them all (note that all requests must use the same model).
|
||||
|
||||
### Step 2: Run the batch
|
||||
|
||||
You can run the batch using the same command as in earlier examples.
|
||||
|
||||
### Step 3: Check your results
|
||||
|
||||
You can check your results by running `cat results.jsonl`
|
||||
|
||||
```
|
||||
$ cat results.jsonl
|
||||
{"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
|
||||
{"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
|
||||
```
|
||||
|
||||
@ -1,32 +0,0 @@
|
||||
"""
|
||||
Example of using the OpenAI entrypoint's rerank API which is compatible with
|
||||
the Cohere SDK: https://github.com/cohere-ai/cohere-python
|
||||
|
||||
run: vllm serve BAAI/bge-reranker-base
|
||||
"""
|
||||
import cohere
|
||||
|
||||
# cohere v1 client
|
||||
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
|
||||
rerank_v1_result = co.rerank(
|
||||
model="BAAI/bge-reranker-base",
|
||||
query="What is the capital of France?",
|
||||
documents=[
|
||||
"The capital of France is Paris", "Reranking is fun!",
|
||||
"vLLM is an open-source framework for fast AI serving"
|
||||
])
|
||||
|
||||
print(rerank_v1_result)
|
||||
|
||||
# or the v2
|
||||
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
|
||||
|
||||
v2_rerank_result = co2.rerank(
|
||||
model="BAAI/bge-reranker-base",
|
||||
query="What is the capital of France?",
|
||||
documents=[
|
||||
"The capital of France is Paris", "Reranking is fun!",
|
||||
"vLLM is an open-source framework for fast AI serving"
|
||||
])
|
||||
|
||||
print(v2_rerank_result)
|
||||
@ -1,33 +0,0 @@
|
||||
"""
|
||||
Example of using the OpenAI entrypoint's rerank API which is compatible with
|
||||
Jina and Cohere https://jina.ai/reranker
|
||||
|
||||
run: vllm serve BAAI/bge-reranker-base
|
||||
"""
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://127.0.0.1:8000/rerank"
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
"model":
|
||||
"BAAI/bge-reranker-base",
|
||||
"query":
|
||||
"What is the capital of France?",
|
||||
"documents": [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.", "Horses and cows are both animals"
|
||||
]
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
|
||||
# Check the response
|
||||
if response.status_code == 200:
|
||||
print("Request successful!")
|
||||
print(json.dumps(response.json(), indent=2))
|
||||
else:
|
||||
print(f"Request failed with status code: {response.status_code}")
|
||||
print(response.text)
|
||||
@ -5,7 +5,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
|
||||
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||
@ -34,6 +34,6 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.9.0 # required for compressed-tensors
|
||||
compressed-tensors == 0.9.1 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
@ -2,7 +2,7 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
#
|
||||
absl-py==2.1.0
|
||||
# via rouge-score
|
||||
@ -106,9 +106,17 @@ dnspython==2.7.0
|
||||
docutils==0.16
|
||||
# via awscli
|
||||
einops==0.8.0
|
||||
# via -r requirements-test.in
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
einx==0.3.0
|
||||
# via vector-quantize-pytorch
|
||||
email-validator==2.2.0
|
||||
# via pydantic
|
||||
encodec==0.1.1
|
||||
# via vocos
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastparquet==2024.11.0
|
||||
@ -125,6 +133,8 @@ filelock==3.16.1
|
||||
# triton
|
||||
fonttools==4.54.1
|
||||
# via matplotlib
|
||||
frozendict==2.4.6
|
||||
# via einx
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
@ -159,6 +169,7 @@ huggingface-hub==0.26.2
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
# vocos
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
@ -261,6 +272,8 @@ numpy==1.26.4
|
||||
# cupy-cuda12x
|
||||
# datasets
|
||||
# decord
|
||||
# einx
|
||||
# encodec
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
@ -283,6 +296,7 @@ numpy==1.26.4
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
# vocos
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -455,6 +469,7 @@ pyyaml==6.0.2
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
# vocos
|
||||
ray[adag]==2.40.0
|
||||
# via -r requirements-test.in
|
||||
redis==5.2.0
|
||||
@ -517,6 +532,7 @@ scipy==1.13.1
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
# vocos
|
||||
sentence-transformers==3.2.1
|
||||
# via -r requirements-test.in
|
||||
sentencepiece==0.2.0
|
||||
@ -540,7 +556,9 @@ sqlitedict==2.1.0
|
||||
statsmodels==0.14.4
|
||||
# via genai-perf
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
# via
|
||||
# einx
|
||||
# torch
|
||||
tabledata==1.3.3
|
||||
# via pytablewriter
|
||||
tabulate==0.9.0
|
||||
@ -568,12 +586,21 @@ torch==2.5.1
|
||||
# -r requirements-test.in
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# encodec
|
||||
# lm-eval
|
||||
# peft
|
||||
# sentence-transformers
|
||||
# tensorizer
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
torchaudio==2.5.1
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vocos
|
||||
torchvision==0.20.1
|
||||
# via timm
|
||||
tqdm==4.66.6
|
||||
@ -584,13 +611,15 @@ tqdm==4.66.6
|
||||
# lm-eval
|
||||
# nltk
|
||||
# peft
|
||||
# pqdm
|
||||
# sentence-transformers
|
||||
# tqdm-multiprocess
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.47.0
|
||||
transformers==4.48.2
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# peft
|
||||
@ -615,6 +644,7 @@ typing-extensions==4.12.2
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# mistral-common
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# torch
|
||||
@ -626,6 +656,10 @@ urllib3==2.2.3
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
vector-quantize-pytorch==1.21.2
|
||||
# via -r requirements-test.in
|
||||
vocos==0.1.0
|
||||
# via -r requirements-test.in
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
xxhash==3.5.0
|
||||
@ -638,4 +672,4 @@ zstandard==0.23.0
|
||||
# via lm-eval
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
# setuptools
|
||||
@ -10,17 +10,14 @@ wheel
|
||||
jinja2
|
||||
ray[default]
|
||||
|
||||
# Install torch, torch_xla
|
||||
# Install torch_xla
|
||||
--pre
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
# Note: This torch whl can be slightly different from the official torch nightly whl
|
||||
# since they are not built on the same commit (but on the same day). This difference may cause C++ undefined symbol issue
|
||||
# if some change between the 2 commits introduce some C++ API change.
|
||||
# Here we install the exact torch whl from which torch_xla is built from, to avoid potential C++ undefined symbol issue.
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch==2.6.0.dev20241216+cpu
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
5
setup.py
Executable file → Normal file
5
setup.py
Executable file → Normal file
@ -598,10 +598,7 @@ if _is_hip():
|
||||
|
||||
if _is_cuda():
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
|
||||
# FA3 requires CUDA 12.0 or later
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
|
||||
@ -25,32 +25,27 @@ def _query_server_long(prompt: str) -> dict:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(tokenizer_pool_size: int, distributed_executor_backend: str):
|
||||
def api_server(tokenizer_pool_size: int, worker_use_ray: bool):
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
commands = [
|
||||
sys.executable,
|
||||
"-u",
|
||||
str(script_path),
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--tokenizer-pool-size",
|
||||
str(tokenizer_pool_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_executor_backend,
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m", "--host",
|
||||
"127.0.0.1", "--tokenizer-pool-size",
|
||||
str(tokenizer_pool_size)
|
||||
]
|
||||
|
||||
if worker_use_ray:
|
||||
commands.append("--worker-use-ray")
|
||||
uvicorn_process = subprocess.Popen(commands)
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
|
||||
@pytest.mark.parametrize("worker_use_ray", [False, True])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int,
|
||||
distributed_executor_backend: str):
|
||||
worker_use_ray: bool):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
|
||||
@ -29,10 +29,10 @@ def check_settings():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distributed_executor_backend() -> str:
|
||||
# When SPMD worker is used, use distributed_executor_backend="ray"
|
||||
def worker_use_ray() -> bool:
|
||||
# When SPMD worker is used, use ray_use_worker=True
|
||||
# to test delta input optimization works with preemption.
|
||||
return "ray" if envs.VLLM_USE_RAY_SPMD_WORKER else "mp"
|
||||
return envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@ -47,7 +47,7 @@ def test_chunked_prefill_recompute(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
distributed_executor_backend: str,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Ensure that chunked prefill works with preemption."""
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
@ -66,7 +66,7 @@ def test_chunked_prefill_recompute(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
worker_use_ray=worker_use_ray,
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
@ -93,7 +93,7 @@ def test_preemption(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
distributed_executor_backend: str,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""By default, recompute preemption is enabled"""
|
||||
|
||||
@ -104,7 +104,7 @@ def test_preemption(
|
||||
model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
@ -144,7 +144,7 @@ def test_preemption_infeasible(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
distributed_executor_backend: str,
|
||||
worker_use_ray: bool,
|
||||
) -> None:
|
||||
"""Verify infeasible preemption request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
@ -159,7 +159,7 @@ def test_preemption_infeasible(
|
||||
# ignored instead of hanging forever.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
|
||||
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
worker_use_ray=worker_use_ray,
|
||||
) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
|
||||
@ -20,7 +20,7 @@ TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.58
|
||||
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
|
||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||
MORE_ARGS_LIST = [
|
||||
[], # Default
|
||||
["--enable-chunked-prefill"], # Chunked
|
||||
@ -66,14 +66,21 @@ def run_test(more_args):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="V1 currently only supported on CUDA")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 currently only supported on CUDA and TPU")
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
run_test([])
|
||||
more_args = []
|
||||
|
||||
# Limit compilation time for V1
|
||||
if current_platform.is_tpu():
|
||||
more_args = ["--max-num-seqs", "64"]
|
||||
|
||||
run_test(more_args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from vllm.entrypoints.openai.protocol import RerankResponse
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-base"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = ["--enforce-eager", "--max-model-len", "100"]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
})
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
assert rerank.results[0].relevance_score >= 0.9
|
||||
assert rerank.results[1].relevance_score <= 0.01
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_top_n(server: RemoteOpenAIServer, model_name: str):
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.", "Cross-encoder models are neat"
|
||||
]
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": 2
|
||||
})
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
assert rerank.results[0].relevance_score >= 0.9
|
||||
assert rerank.results[1].relevance_score <= 0.01
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
query = "What is the capital of France?" * 100
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": documents
|
||||
})
|
||||
assert rerank_response.status_code == 400
|
||||
# Assert just a small fragments of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
rerank_response.text
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -22,9 +21,6 @@ INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "
|
||||
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}
|
||||
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
|
||||
|
||||
INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}"""
|
||||
|
||||
|
||||
def test_empty_file():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
@ -106,36 +102,3 @@ def test_embeddings():
|
||||
# Ensure that the output format conforms to the openai api.
|
||||
# Validation should throw if the schema is wrong.
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
|
||||
def test_score():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
input_file.write(INPUT_SCORE_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
sys.executable,
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.run_batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
], )
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
|
||||
contents = output_file.read()
|
||||
for line in contents.strip().split("\n"):
|
||||
# Ensure that the output format conforms to the openai api.
|
||||
# Validation should throw if the schema is wrong.
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
# Ensure that there is no error in the response.
|
||||
line_dict = json.loads(line)
|
||||
assert isinstance(line_dict, dict)
|
||||
assert line_dict["error"] is None
|
||||
|
||||
@ -10,7 +10,12 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = ["--enforce-eager", "--max-model-len", "100"]
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
# Will be used on tests to compare prompt input length
|
||||
"--max-model-len",
|
||||
"100"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@ -103,116 +103,6 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Setting server's max_tokens in the generation_config.json
|
||||
# lower than context_window - prompt_tokens
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"max_tokens": 10 # Setting server-side max_tokens limit
|
||||
}
|
||||
|
||||
# Reinitialize the engine with new settings
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
|
||||
# Initialize the serving chat
|
||||
models = OpenAIServingModels(engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config)
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
mock_model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
|
||||
# Test Case 1: No max_tokens specified in request
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
guided_decoding_backend="outlines",
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Test Case 2: Request's max_tokens set higher than server accepts
|
||||
req.max_tokens = 15
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Test Case 3: Request's max_tokens set lower than server accepts
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 5
|
||||
|
||||
# Setting server's max_tokens in the generation_config.json
|
||||
# higher than context_window - prompt_tokens
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"max_tokens": 200 # Setting server-side max_tokens limit
|
||||
}
|
||||
|
||||
# Reinitialize the engine with new settings
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
|
||||
# Initialize the serving chat
|
||||
models = OpenAIServingModels(engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config)
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
mock_model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
|
||||
# Test case 1: No max_tokens specified, defaults to context_window
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
guided_decoding_backend="outlines",
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 93
|
||||
|
||||
# Test Case 2: Request's max_tokens set higher than server accepts
|
||||
req.max_tokens = 100
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 93
|
||||
|
||||
# Test Case 3: Request's max_tokens set lower than server accepts
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 5
|
||||
|
||||
|
||||
def test_serving_chat_could_load_correct_generation_config():
|
||||
|
||||
|
||||
11
tests/kernels/test_cascade_flash_attn.py
Executable file → Normal file
11
tests/kernels/test_cascade_flash_attn.py
Executable file → Normal file
@ -6,9 +6,7 @@ import torch
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
|
||||
merge_attn_states)
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
is_fa_version_supported)
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 192, 256]
|
||||
@ -93,9 +91,10 @@ def test_cascade(
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -11,8 +11,6 @@ from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
@ -43,10 +41,34 @@ capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
|
||||
@ -1,214 +0,0 @@
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/test_semi_structured.py`.
|
||||
"""
|
||||
from typing import Tuple, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
|
||||
big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 512),
|
||||
(16, 256, 512),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 512),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 512),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=dtype)
|
||||
baseline = F.linear(a, b.T)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
@ -4,10 +4,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@ -97,9 +95,10 @@ def test_flash_attn_with_paged_kv(
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
@ -183,9 +182,11 @@ def test_varlen_with_paged_kv(
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
|
||||
@ -1,126 +0,0 @@
|
||||
"""
|
||||
Test:
|
||||
|
||||
* Tests for MultiHeadAttention layer
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.selector import _Backend, _cached_get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||
def test_mha_attn_platform(device: str):
|
||||
"""
|
||||
Test the attention selector between different platform and device.
|
||||
"""
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
else:
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
|
||||
def ref_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Native implementation of scaled dot product attention without mask:
|
||||
- query, key, value: [batch_size, seq_len, num_heads, head_size]
|
||||
- attn_mask: [batch_size, seq_len, seq_len]
|
||||
"""
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.matmul(attn_weights, value).transpose(1, 2)
|
||||
return out
|
||||
|
||||
|
||||
BATCH_SIZES = [1, 16]
|
||||
SEQ_LENS = [1]
|
||||
NUM_HEADS = [1, 16]
|
||||
NUM_KV_HEADS = [1]
|
||||
HEAD_SIZES = [64, 80]
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = [
|
||||
torch.half, torch.bfloat16, torch.float
|
||||
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
|
||||
CUDA_DEVICES = ["cuda"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_mha_attn_forward(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
q = torch.randn(batch_size, seq_len, num_heads * head_size)
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
scale = 1.0 / head_size**0.5
|
||||
attn = MultiHeadAttention(num_heads,
|
||||
head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads)
|
||||
output = attn(q, k, v)
|
||||
|
||||
assert num_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_heads // num_kv_heads
|
||||
q = q.reshape(batch_size, seq_len, num_heads, head_size)
|
||||
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
||||
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
|
||||
if num_queries_per_kv > 1:
|
||||
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
|
||||
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
|
||||
|
||||
ref_output = ref_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=scale,
|
||||
).reshape(batch_size, seq_len, num_heads * head_size)
|
||||
torch.testing.assert_close(output, ref_output)
|
||||
134
tests/kernels/test_semi_structured.py
Normal file
134
tests/kernels/test_semi_structured.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/test_semi_structured.py`.
|
||||
"""
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
|
||||
big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
@ -5,7 +5,7 @@ import random
|
||||
import unittest
|
||||
from numbers import Number
|
||||
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
|
||||
Type, Union)
|
||||
Union)
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -1100,28 +1100,3 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
||||
kwargs,
|
||||
test_utils=test_utils,
|
||||
raise_exception=raise_exception) if cond else {}
|
||||
|
||||
|
||||
# For testing quantized linear kernels
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
@ -16,8 +16,7 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
DEFAULT_SERVER_ARGS: List[str] = [
|
||||
"--distributed-executor-backend",
|
||||
"ray",
|
||||
"--worker-use-ray",
|
||||
"--gpu-memory-utilization",
|
||||
"0.85",
|
||||
"--swap-space",
|
||||
|
||||
@ -313,10 +313,8 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="2of4 Sparse is not yet supported on this GPU type."
|
||||
)
|
||||
reason="Sparse FP8 is not yet supported on this GPU type.")
|
||||
@pytest.mark.parametrize(
|
||||
"args_2of4",
|
||||
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
|
||||
|
||||
@ -34,4 +34,4 @@ run_mypy vllm/plugins
|
||||
run_mypy vllm/prompt_adapter
|
||||
run_mypy vllm/spec_decode
|
||||
run_mypy vllm/worker
|
||||
run_mypy vllm/v1
|
||||
run_mypy vllm/v1
|
||||
@ -26,4 +26,4 @@ class ImageAsset:
|
||||
"""
|
||||
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
|
||||
s3_prefix=VLM_IMAGES_DIR)
|
||||
return torch.load(image_path, map_location="cpu", weights_only=True)
|
||||
return torch.load(image_path, map_location="cpu")
|
||||
|
||||
14
vllm/attention/backends/flash_attn.py
Executable file → Normal file
14
vllm/attention/backends/flash_attn.py
Executable file → Normal file
@ -18,20 +18,17 @@ from vllm.attention.backends.utils import (
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
@ -655,11 +652,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
self.fa_version = VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
if not is_fa_version_supported(self.fa_version):
|
||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||
self.fa_version,
|
||||
fa_version_unsupported_reason(self.fa_version))
|
||||
|
||||
assert is_fa_version_supported(self.fa_version)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -210,9 +210,6 @@ class MultiHeadAttention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
@ -224,8 +221,7 @@ class MultiHeadAttention(nn.Module):
|
||||
backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
@ -235,7 +231,7 @@ class MultiHeadAttention(nn.Module):
|
||||
value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: batch_size x seq_len x hidden_size"""
|
||||
# TODO(Isotr0py): Use existing backend implementations and support FA3
|
||||
# TODO(Isotr0py): Use existing backend implementations and support FA2
|
||||
bsz, q_len, _ = query.size()
|
||||
kv_len = key.size(1)
|
||||
|
||||
@ -243,11 +239,6 @@ class MultiHeadAttention(nn.Module):
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
|
||||
@ -910,18 +910,12 @@ class ModelConfig:
|
||||
"top_k",
|
||||
"top_p",
|
||||
"min_p",
|
||||
"max_new_tokens",
|
||||
]
|
||||
if any(p in config for p in available_params):
|
||||
diff_sampling_param = {
|
||||
p: config.get(p)
|
||||
for p in available_params if config.get(p) is not None
|
||||
}
|
||||
# Huggingface definition of max_new_tokens is equivalent
|
||||
# to vLLM's max_tokens
|
||||
if "max_new_tokens" in diff_sampling_param:
|
||||
diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
|
||||
"max_new_tokens")
|
||||
else:
|
||||
diff_sampling_param = {}
|
||||
return diff_sampling_param
|
||||
@ -1233,6 +1227,9 @@ class ParallelConfig:
|
||||
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
|
||||
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
|
||||
|
||||
# Deprecated, use distributed_executor_backend instead.
|
||||
worker_use_ray: Optional[bool] = None
|
||||
|
||||
# Maximum number of multiple batches
|
||||
# when load model sequentially. To avoid RAM OOM when using tensor
|
||||
# parallel and large models.
|
||||
@ -1286,6 +1283,13 @@ class ParallelConfig:
|
||||
self.world_size = self.pipeline_parallel_size * \
|
||||
self.tensor_parallel_size
|
||||
|
||||
if self.worker_use_ray:
|
||||
if self.distributed_executor_backend is None:
|
||||
self.distributed_executor_backend = "ray"
|
||||
elif not self.use_ray:
|
||||
raise ValueError(f"worker-use-ray can't be used with "
|
||||
f"distributed executor backend "
|
||||
f"'{self.distributed_executor_backend}'.")
|
||||
ray_only_devices = ["tpu"]
|
||||
from vllm.platforms import current_platform
|
||||
if (current_platform.device_type in ray_only_devices
|
||||
|
||||
@ -100,6 +100,7 @@ class EngineArgs:
|
||||
kv_cache_dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
max_model_len: Optional[int] = None
|
||||
worker_use_ray: bool = False
|
||||
# Note: Specifying a custom executor backend by passing a class
|
||||
# is intended for expert use only. The API may change without
|
||||
# notice.
|
||||
@ -388,6 +389,10 @@ class EngineArgs:
|
||||
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
|
||||
'only supports Ray for distributed inference.')
|
||||
|
||||
parser.add_argument(
|
||||
'--worker-use-ray',
|
||||
action='store_true',
|
||||
help='Deprecated, use ``--distributed-executor-backend=ray``.')
|
||||
parser.add_argument('--pipeline-parallel-size',
|
||||
'-pp',
|
||||
type=int,
|
||||
@ -939,9 +944,7 @@ class EngineArgs:
|
||||
"Defaults to None, will use the default generation config in vLLM. "
|
||||
"If set to 'auto', the generation config will be automatically "
|
||||
"loaded from model. If set to a folder path, the generation config "
|
||||
"will be loaded from the specified folder path. If "
|
||||
"`max_new_tokens` is specified, then it sets a server-wide limit "
|
||||
"on the number of output tokens for all requests.")
|
||||
"will be loaded from the specified folder path.")
|
||||
|
||||
parser.add_argument("--enable-sleep-mode",
|
||||
action="store_true",
|
||||
@ -1068,6 +1071,7 @@ class EngineArgs:
|
||||
parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
worker_use_ray=self.worker_use_ray,
|
||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||
tokenizer_pool_config=TokenizerPoolConfig.create_config(
|
||||
@ -1275,22 +1279,11 @@ class EngineArgs:
|
||||
self.enable_chunked_prefill = True
|
||||
# When no user override, set the default values based on the usage
|
||||
# context.
|
||||
# Use different default values for different hardware.
|
||||
from vllm.platforms import current_platform
|
||||
device_name = current_platform.get_device_name().lower()
|
||||
if "h100" in device_name or "h200" in device_name:
|
||||
# For H100 and H200, we use larger default values.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 16384,
|
||||
UsageContext.OPENAI_API_SERVER: 8192,
|
||||
}
|
||||
else:
|
||||
# TODO(woosuk): Tune the default values for other hardware.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 8192,
|
||||
UsageContext.OPENAI_API_SERVER: 2048,
|
||||
}
|
||||
|
||||
# TODO(woosuk): Tune the default values for different hardware.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 8192,
|
||||
UsageContext.OPENAI_API_SERVER: 2048,
|
||||
}
|
||||
if (self.max_num_batched_tokens is None
|
||||
and usage_context in default_max_num_batched_tokens):
|
||||
self.max_num_batched_tokens = default_max_num_batched_tokens[
|
||||
|
||||
@ -259,6 +259,21 @@ class Metrics:
|
||||
documentation="Number of emitted tokens.",
|
||||
labelnames=labelnames))
|
||||
|
||||
# Deprecated in favor of vllm:prompt_tokens_total
|
||||
self.gauge_avg_prompt_throughput = self._gauge_cls(
|
||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||
documentation="Average prefill throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum",
|
||||
)
|
||||
# Deprecated in favor of vllm:generation_tokens_total
|
||||
self.gauge_avg_generation_throughput = self._gauge_cls(
|
||||
name="vllm:avg_generation_throughput_toks_per_s",
|
||||
documentation="Average generation throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum",
|
||||
)
|
||||
|
||||
|
||||
# end-metrics-definitions
|
||||
|
||||
@ -620,6 +635,20 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self._log_histogram(self.metrics.histogram_max_tokens_request,
|
||||
stats.max_tokens_requests)
|
||||
|
||||
def _log_prometheus_interval(self, prompt_throughput: float,
|
||||
generation_throughput: float) -> None:
|
||||
# Logs metrics to prometheus that are computed every logging_interval.
|
||||
# Support legacy gauge metrics that make throughput calculations on
|
||||
# the vLLM side. Moving forward, we should use counters like
|
||||
# counter_prompt_tokens, counter_generation_tokens
|
||||
# Which log raw data and calculate summaries using rate() on the
|
||||
# grafana/prometheus side. See
|
||||
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
|
||||
self.metrics.gauge_avg_prompt_throughput.labels(
|
||||
**self.labels).set(prompt_throughput)
|
||||
self.metrics.gauge_avg_generation_throughput.labels(
|
||||
**self.labels).set(generation_throughput)
|
||||
|
||||
def log(self, stats: Stats):
|
||||
"""Logs to prometheus and tracked stats every iteration."""
|
||||
# Log to prometheus.
|
||||
@ -635,6 +664,20 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
# Log locally every local_interval seconds.
|
||||
if local_interval_elapsed(stats.now, self.last_local_log,
|
||||
self.local_interval):
|
||||
# Compute summary metrics for tracked stats (and log them
|
||||
# to promethus if applicable).
|
||||
prompt_throughput = get_throughput(self.num_prompt_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
generation_throughput = get_throughput(
|
||||
self.num_generation_tokens,
|
||||
now=stats.now,
|
||||
last_log=self.last_local_log)
|
||||
|
||||
self._log_prometheus_interval(
|
||||
prompt_throughput=prompt_throughput,
|
||||
generation_throughput=generation_throughput)
|
||||
|
||||
if self.spec_decode_metrics is not None:
|
||||
self._log_gauge(
|
||||
self.metrics.gauge_spec_decode_draft_acceptance_rate,
|
||||
|
||||
@ -56,7 +56,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
@ -69,7 +68,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
@ -308,10 +306,6 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
|
||||
return request.app.state.jinaai_serving_reranking
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
@ -508,40 +502,6 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Rerank (Score) API")
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client"
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)")
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v2/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"generate": {
|
||||
"messages": (ChatCompletionRequest, create_chat_completion),
|
||||
@ -552,10 +512,7 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"default": (EmbeddingCompletionRequest, create_embedding),
|
||||
},
|
||||
"score": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"rerank": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
"default": (ScoreRequest, create_score),
|
||||
},
|
||||
"reward": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
@ -802,12 +759,6 @@ async def init_app_state(
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.jinaai_serving_reranking = JinaAIServingRerank(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
|
||||
@ -380,17 +380,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
) -> BeamSearchParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
n = self.n if self.n is not None else 1
|
||||
|
||||
# Use minimum of context window, user request & server limit.
|
||||
max_tokens = min(
|
||||
val for val in (default_max_tokens, max_tokens,
|
||||
default_sampling_params.get("max_tokens", None))
|
||||
if val is not None)
|
||||
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
@ -410,16 +406,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
# Use minimum of context window, user request & server limit.
|
||||
max_tokens = min(
|
||||
val for val in (default_max_tokens, max_tokens,
|
||||
default_sampling_params.get("max_tokens", None))
|
||||
if val is not None)
|
||||
|
||||
# Default parameters
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
@ -749,17 +740,13 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
default_sampling_params: Optional[dict] = None
|
||||
) -> BeamSearchParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
n = self.n if self.n is not None else 1
|
||||
|
||||
# Use minimum of context window, user request & server limit.
|
||||
max_tokens = min(
|
||||
val for val in (default_max_tokens, max_tokens,
|
||||
default_sampling_params.get("max_tokens", None))
|
||||
if val is not None)
|
||||
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get("temperature", 1.0)
|
||||
|
||||
@ -777,16 +764,11 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
logits_processor_pattern: Optional[str],
|
||||
default_sampling_params: Optional[dict] = None) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
# Use minimum of context window, user request & server limit.
|
||||
max_tokens = min(
|
||||
val for val in (default_max_tokens, max_tokens,
|
||||
default_sampling_params.get("max_tokens", None))
|
||||
if val is not None)
|
||||
|
||||
# Default parameters
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get(
|
||||
@ -1018,52 +1000,6 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
query: str
|
||||
documents: List[str]
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
|
||||
# doc: begin-rerank-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
# doc: end-rerank-pooling-params
|
||||
|
||||
# doc: begin-rerank-extra-params
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-rerank-extra-params
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
index: int
|
||||
document: RerankDocument
|
||||
relevance_score: float
|
||||
|
||||
|
||||
class RerankUsage(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class RerankResponse(OpenAIBaseModel):
|
||||
id: str
|
||||
model: str
|
||||
usage: RerankUsage
|
||||
results: List[RerankResult]
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
@ -1283,7 +1219,7 @@ class BatchRequestInput(OpenAIBaseModel):
|
||||
url: str
|
||||
|
||||
# The parameters of the request.
|
||||
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
|
||||
body: Union[ChatCompletionRequest, EmbeddingRequest]
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
@ -1294,8 +1230,7 @@ class BatchResponseData(OpenAIBaseModel):
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
|
||||
ScoreResponse]] = None
|
||||
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
|
||||
@ -16,14 +16,12 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
ScoreResponse)
|
||||
EmbeddingResponse, ErrorResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@ -169,8 +167,7 @@ async def run_request(serving_engine_func: Callable,
|
||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
|
||||
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
@ -242,12 +239,6 @@ async def main(args):
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if model_config.task == "embed" else None
|
||||
openai_serving_scores = (OpenAIServingScores(
|
||||
engine,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
) if model_config.task == "score" else None)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
@ -288,28 +279,14 @@ async def main(args):
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/score":
|
||||
handler_fn = (None if openai_serving_scores is None else
|
||||
openai_serving_scores.create_score)
|
||||
if handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=
|
||||
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
|
||||
"are supported in the batch endpoint.",
|
||||
error_msg="Only /v1/chat/completions and "
|
||||
"/v1/embeddings are supported in the batch endpoint.",
|
||||
))
|
||||
|
||||
with tracker.pbar():
|
||||
|
||||
@ -26,8 +26,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ErrorResponse, RerankRequest,
|
||||
ScoreRequest,
|
||||
ErrorResponse, ScoreRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@ -205,9 +204,9 @@ class OpenAIServing:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
|
||||
if isinstance(request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
||||
ScoreRequest, RerankRequest)):
|
||||
if isinstance(
|
||||
request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
|
||||
|
||||
operation = "score" if isinstance(request, ScoreRequest) \
|
||||
else "embedding generation"
|
||||
|
||||
@ -1,206 +0,0 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JinaAIServingRerank(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
query = request.query
|
||||
documents = request.documents
|
||||
request_prompts = []
|
||||
engine_prompts = []
|
||||
top_n = request.top_n if request.top_n > 0 else len(documents)
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
if not self.model_config.is_cross_encoder:
|
||||
raise ValueError("Model is not cross encoder.")
|
||||
|
||||
if truncate_prompt_tokens is not None and \
|
||||
truncate_prompt_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({self.max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
for doc in documents:
|
||||
request_prompt = f"{query}{tokenizer.sep_token}{doc}"
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(text=query,
|
||||
text_pair=doc,
|
||||
**tokenization_kwargs)
|
||||
|
||||
input_ids = prompt_inputs["input_ids"]
|
||||
text_token_prompt = \
|
||||
self._validate_input(request, input_ids, request_prompt)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_rerank_response(
|
||||
final_res_batch_checked, request_id, model_name, documents,
|
||||
top_n)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: List[str],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: List[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
@ -273,8 +273,7 @@ class LoRAModel(AdapterModel):
|
||||
new_embeddings_tensor_path)
|
||||
elif os.path.isfile(new_embeddings_bin_file_path):
|
||||
embeddings = torch.load(new_embeddings_bin_file_path,
|
||||
map_location=device,
|
||||
weights_only=True)
|
||||
map_location=device)
|
||||
|
||||
return cls.from_lora_tensors(
|
||||
lora_model_id=get_lora_id()
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,21 +1,21 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -23,10 +23,10 @@
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -34,10 +34,10 @@
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -48,7 +48,7 @@
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -56,10 +56,10 @@
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -67,8 +67,19 @@
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
@ -76,48 +87,37 @@
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 32,
|
||||
"kpack": 2
|
||||
},
|
||||
"256": {
|
||||
@ -129,7 +129,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -150,7 +150,7 @@
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"matrix_instr_nonkdim": 32,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
@ -184,7 +184,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -195,6 +195,6 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,200 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,10 +1,10 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -19,14 +19,14 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -34,52 +34,41 @@
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"48": {
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
@ -87,23 +76,34 @@
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -112,24 +112,24 @@
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
@ -151,7 +151,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -162,7 +162,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -184,7 +184,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -195,6 +195,6 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,200 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 32,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,21 +1,21 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -23,10 +23,10 @@
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -34,7 +34,7 @@
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
@ -52,9 +52,31 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
@ -65,28 +87,6 @@
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
@ -101,40 +101,40 @@
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
@ -151,7 +151,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -173,7 +173,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -195,6 +195,6 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,200 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 32,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,7 +1,7 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
@ -12,54 +12,54 @@
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"24": {
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -68,7 +68,7 @@
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
@ -78,32 +78,32 @@
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
@ -112,18 +112,18 @@
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
@ -140,7 +140,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -151,7 +151,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -173,7 +173,7 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
@ -187,7 +187,7 @@
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
@ -195,6 +195,6 @@
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
"kpack": 1
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -1,200 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
||||
@ -9,7 +9,6 @@ from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationType)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
@ -28,8 +27,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||
@ -82,8 +79,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
if scheme is None:
|
||||
return UnquantizedLinearMethod()
|
||||
layer.scheme = scheme
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
@ -345,10 +340,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
|
||||
def get_scheme(self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
def get_scheme(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
@ -358,7 +353,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
which can be a full layer_name, a regex for a layer_name, or
|
||||
an nn.Module name.
|
||||
|
||||
Detect whether a layer_name is found in any target and
|
||||
We first check whether a layer is in the ignore group and use
|
||||
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
|
||||
|
||||
We then detect whether a layer_name is found in any target and
|
||||
use the quantization scheme corresponding to the matched target
|
||||
to select the CompressedTensorsScheme used for infernece.
|
||||
"""
|
||||
@ -396,13 +394,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self.supports_cutlass_24(weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
sparsity_scheme=sparsity_scheme):
|
||||
# FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels
|
||||
# currently produce bad output in some cases
|
||||
if weight_quant is None:
|
||||
logger.warning_once(
|
||||
"CompressedTensors24 scheme is disabled for the w16a16 "
|
||||
"case. Falling back to UnquantizedLinearMethod")
|
||||
return None
|
||||
# Have a valid sparsity scheme
|
||||
# Validate layer is supported by Cutlass 2:4 Kernel
|
||||
scheme = CompressedTensors24(quantized=weight_quant is not None
|
||||
|
||||
@ -93,7 +93,7 @@ def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
@ -381,9 +381,7 @@ def np_cache_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
map_location="cpu",
|
||||
weights_only=True)
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
@ -449,7 +447,7 @@ def pt_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu", weights_only=True)
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
yield from state.items()
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -481,14 +481,14 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
|
||||
bos_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
image_token_id = vocab["<image>"]
|
||||
image_token_id = vocab["image"]
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[bos_token_id],
|
||||
target="</s>",
|
||||
replacement=PromptReplacementDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
|
||||
@ -348,7 +348,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config # Required by MixtralForCausalLM
|
||||
|
||||
self.model = GraniteMoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
@ -135,7 +135,7 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
"vllm.v1.worker.gpu_worker.GPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ class _Backend(enum.Enum):
|
||||
FLASHINFER = enum.auto()
|
||||
HPU_ATTN = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
@ -30,10 +31,16 @@ class TpuPlatform(Platform):
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool) -> str:
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
if (selected_backend != _Backend.PALLAS
|
||||
and selected_backend != _Backend.PALLAS_VLLM_V1):
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
logger.info("Using Pallas backend.")
|
||||
return "vllm.attention.backends.pallas.PallasAttentionBackend"
|
||||
|
||||
if use_v1:
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
else:
|
||||
logger.info("Using Pallas backend.")
|
||||
return "vllm.attention.backends.pallas.PallasAttentionBackend"
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
@ -45,7 +52,7 @@ class TpuPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
return not envs.VLLM_USE_V1
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
@ -60,11 +67,11 @@ class TpuPlatform(Platform):
|
||||
cache_config.block_size = 16
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if compilation_config.level == CompilationLevel.NO_COMPILATION:
|
||||
# TPU does not support NO_COMPILATION
|
||||
|
||||
# TPU only supports DYNAMO_ONCE compilation level
|
||||
if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
|
||||
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
|
||||
compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
assert compilation_config.level < CompilationLevel.PIECEWISE,\
|
||||
"TPU does not support Inductor."
|
||||
|
||||
if compilation_config.backend == "":
|
||||
compilation_config.backend = "openxla"
|
||||
@ -72,10 +79,6 @@ class TpuPlatform(Platform):
|
||||
assert vllm_config.speculative_config is None, \
|
||||
"TPU does not support speculative decoding"
|
||||
|
||||
assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
|
||||
"Chunked prefill is not yet supported for TPU backend")
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
if vllm_config.model_config.dtype in (torch.float16, torch.float32):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
@ -85,8 +88,27 @@ class TpuPlatform(Platform):
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
|
||||
"vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.tpu_worker.TPUWorker"
|
||||
|
||||
# Adjust scheduler config for V1
|
||||
# TODO: Add support for these
|
||||
if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
|
||||
logger.warning("[V1][TPU] Disable prefix caching")
|
||||
vllm_config.cache_config.enable_prefix_caching = False
|
||||
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on TPU.")
|
||||
return False
|
||||
|
||||
@ -89,7 +89,6 @@ def load_peft_weights(model_id: str,
|
||||
adapters_weights = safe_load_file(filename, device=device)
|
||||
else:
|
||||
adapters_weights = torch.load(filename,
|
||||
map_location=torch.device(device),
|
||||
weights_only=True)
|
||||
map_location=torch.device(device))
|
||||
|
||||
return adapters_weights
|
||||
|
||||
@ -145,8 +145,7 @@ class S3Model:
|
||||
return
|
||||
|
||||
for file in files:
|
||||
destination_file = os.path.join(self.dir,
|
||||
file.removeprefix(base_dir))
|
||||
destination_file = self.dir + file.removeprefix(base_dir)
|
||||
local_dir = Path(destination_file).parent
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
self.s3.download_file(bucket_name, file, destination_file)
|
||||
|
||||
11
vllm/v1/attention/backends/flash_attn.py
Executable file → Normal file
11
vllm/v1/attention/backends/flash_attn.py
Executable file → Normal file
@ -10,15 +10,11 @@ import triton.language as tl
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
is_fa_version_supported)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@ -147,11 +143,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
self.fa_version = VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
if not is_fa_version_supported(self.fa_version):
|
||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||
self.fa_version,
|
||||
fa_version_unsupported_reason(self.fa_version))
|
||||
|
||||
assert is_fa_version_supported(self.fa_version)
|
||||
|
||||
def forward(
|
||||
|
||||
351
vllm/v1/attention/backends/pallas.py
Normal file
351
vllm/v1/attention/backends/pallas.py
Normal file
@ -0,0 +1,351 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["PallasMetadata"]:
|
||||
return PallasMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> None:
|
||||
src_indices, dst_indices = src_to_dists
|
||||
for k_cache, v_cache in kv_caches:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
|
||||
k_cache[:, dst_indices] = k_cache[:, src_indices]
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
|
||||
v_cache[:, dst_indices] = v_cache[:, src_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata(AttentionMetadata):
|
||||
|
||||
# Currently, input sequences can only contain all prefills
|
||||
# or all decoding.
|
||||
block_tables: Optional[torch.Tensor] = None
|
||||
context_lens: Optional[torch.Tensor] = None
|
||||
effective_query_lens: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
assert self.num_decode_tokens == 0
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.block_tables is not None
|
||||
assert self.context_lens is not None
|
||||
return self
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError("Sliding window is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
if logits_soft_cap is not None:
|
||||
raise NotImplementedError(
|
||||
"Attention logits soft-capping is not supported.")
|
||||
|
||||
if torch_xla.tpu.version() < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
|
||||
self.megacore_mode = None
|
||||
tpu_env = torch_xla.tpu.get_tpu_env()
|
||||
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
|
||||
or tpu_env.get("TYPE", None)
|
||||
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
|
||||
assert tpu_type is not None
|
||||
tpu_type = tpu_type.lower()
|
||||
|
||||
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
|
||||
if self.num_kv_heads % 2 == 0:
|
||||
self.megacore_mode = "kv_head"
|
||||
else:
|
||||
# NOTE(woosuk): If the batch size is not a multiple of 2, the
|
||||
# megacore mode will be None.
|
||||
self.megacore_mode = "batch"
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: PallasMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
|
||||
with shape [0] for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
|
||||
if attn_metadata is None:
|
||||
if output is None:
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||
self.head_size)
|
||||
|
||||
if kv_cache[0].numel() > 0:
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
key_cache, value_cache = kv_cache
|
||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
query = query * self.scale
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if attn_metadata.block_tables is None:
|
||||
# Prefill without paged KV cache.
|
||||
assert seq_len % 16 == 0, (
|
||||
"Pallas FlashAttention kernel requires seq_len to be a "
|
||||
f"multiple of 16 but got {seq_len}")
|
||||
|
||||
# Handle GQA/MQA.
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=-2)
|
||||
key = key.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=-2)
|
||||
value = value.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
# FlashAttention kernel requires the input shape to be
|
||||
# [batch_size, num_heads, seq_len, d_model]
|
||||
# while the input is [batch_size, seq_len, num_heads, d_model].
|
||||
# Permute the input to match the required format.
|
||||
output = torch.ops.xla.flash_attention(
|
||||
query.permute(0, 2, 1, 3),
|
||||
key.permute(0, 2, 1, 3),
|
||||
value.permute(0, 2, 1, 3),
|
||||
True,
|
||||
)
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Prefill with paged KV cache.
|
||||
# TODO(woosuk): Tune the below knobs.
|
||||
num_kv_pages_per_compute_block = 16
|
||||
num_queries_per_compute_block = 16
|
||||
assert seq_len % num_queries_per_compute_block == 0
|
||||
output = torch.ops.xla.multi_queries_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.effective_query_lens,
|
||||
num_kv_pages_per_compute_block,
|
||||
num_queries_per_compute_block,
|
||||
use_kernel=True,
|
||||
)
|
||||
else:
|
||||
# Decoding run.
|
||||
assert kv_cache[0].numel() > 0
|
||||
query = query.squeeze(dim=1)
|
||||
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
|
||||
|
||||
assert attn_metadata.block_tables is not None
|
||||
assert attn_metadata.context_lens is not None
|
||||
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
|
||||
# block table in SMEM. Therefore, if the block table is too large,
|
||||
# the kernel compilation will fail. To avoid this, we split the
|
||||
# batch dimension into smaller chunks and run the kernel multiple
|
||||
# times.
|
||||
MAX_SMEM_USAGE = 512 * 1024
|
||||
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
|
||||
max_num_seq = MAX_SMEM_USAGE // size_per_seq
|
||||
|
||||
if batch_size <= max_num_seq:
|
||||
output = paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
pages_per_compute_block,
|
||||
self.megacore_mode,
|
||||
)
|
||||
else:
|
||||
chunk_size = max_num_seq
|
||||
# Make sure the chunk size is a multiple of 2.
|
||||
chunk_size = chunk_size // 2 * 2
|
||||
num_chunks = (batch_size + chunk_size - 1) // chunk_size
|
||||
|
||||
output = torch.empty_like(query)
|
||||
for chunk_idx in range(num_chunks):
|
||||
chunk_start = chunk_idx * chunk_size
|
||||
chunk_end = chunk_start + chunk_size
|
||||
# NOTE(woosuk): We skip this line because it causes Dynamo
|
||||
# compilation error. Instead, we rely on the slice operation
|
||||
# to handle the out-of-bound case.
|
||||
# chunk_end = min(chunk_end, batch_size)
|
||||
chunk_output = paged_attention(
|
||||
query[chunk_start:chunk_end],
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens[chunk_start:chunk_end],
|
||||
attn_metadata.block_tables[chunk_start:chunk_end],
|
||||
pages_per_compute_block,
|
||||
self.megacore_mode,
|
||||
)
|
||||
output[chunk_start:chunk_end] = chunk_output
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.reshape(batch_size, seq_len, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
||||
|
||||
key = key.flatten(0, 2)
|
||||
value = value.flatten(0, 2)
|
||||
key_cache = key_cache.flatten(0, 2)
|
||||
value_cache = value_cache.flatten(0, 2)
|
||||
key_cache.index_copy_(0, slot_mapping, key)
|
||||
value_cache.index_copy_(0, slot_mapping, value)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
pages_per_compute_block: int,
|
||||
megacore_mode: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
batch_size = query.shape[0]
|
||||
if megacore_mode == "batch" and batch_size % 2 != 0:
|
||||
megacore_mode = None
|
||||
else:
|
||||
megacore_mode = megacore_mode
|
||||
|
||||
# NOTE(woosuk): A temporary workaround to avoid the error:
|
||||
# "xla::paged_attention() Expected a value of type 'str' for
|
||||
# argument 'megacore_mode' but instead found type 'NoneType'."
|
||||
if megacore_mode is not None:
|
||||
output = torch.ops.xla.paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
context_lens,
|
||||
block_tables,
|
||||
pages_per_compute_block,
|
||||
megacore_mode=megacore_mode,
|
||||
)
|
||||
else:
|
||||
output = torch.ops.xla.paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
context_lens,
|
||||
block_tables,
|
||||
pages_per_compute_block,
|
||||
)
|
||||
return output
|
||||
@ -8,7 +8,7 @@ import torch
|
||||
class SamplerOutput:
|
||||
|
||||
# [num_reqs]
|
||||
sampled_token_ids: torch.Tensor
|
||||
sampled_token_ids: List[int]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: Optional[torch.Tensor]
|
||||
|
||||
@ -58,8 +58,7 @@ class Request:
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
if self.mm_hashes:
|
||||
assert len(self.mm_inputs) == len(self.mm_hashes)
|
||||
assert len(self.mm_inputs) == len(self.mm_hashes)
|
||||
|
||||
# Cache the computed kv block hashes of the request to avoid
|
||||
# recomputing.
|
||||
|
||||
@ -50,8 +50,9 @@ class Sampler(nn.Module):
|
||||
# Use int32 to reduce the tensor size.
|
||||
sampled = sampled.to(torch.int32)
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=sampled,
|
||||
sampled_token_ids=sampled.tolist(),
|
||||
logprob_token_ids=topk_indices,
|
||||
logprobs=topk_logprobs,
|
||||
prompt_logprob_token_ids=None,
|
||||
|
||||
@ -57,6 +57,14 @@ class BlockTable:
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
def swap_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks_src = self.num_blocks_per_row[src]
|
||||
num_blocks_tgt = self.num_blocks_per_row[tgt]
|
||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
|
||||
@ -72,7 +72,7 @@ class InputBatch:
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
# Block table.
|
||||
self.block_table = BlockTable(
|
||||
@ -436,3 +436,77 @@ class InputBatch:
|
||||
@property
|
||||
def no_prompt_logprob(self) -> bool:
|
||||
return len(self.prompt_logprob_reqs) == 0
|
||||
|
||||
|
||||
def swap_positions(b: InputBatch, id_1, id_2):
|
||||
assert id_1 != id_2
|
||||
req_id_1 = b.req_ids[id_1]
|
||||
req_id_2 = b.req_ids[id_2]
|
||||
assert req_id_1 is not None
|
||||
assert req_id_2 is not None
|
||||
assert id_1 == b.req_id_to_index[req_id_1]
|
||||
assert id_2 == b.req_id_to_index[req_id_2]
|
||||
|
||||
b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
|
||||
b.req_id_to_index[req_id_1], b.req_id_to_index[
|
||||
req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
|
||||
|
||||
ids = [id_1, id_2]
|
||||
rev_ids = [id_2, id_1]
|
||||
b.num_tokens[ids] = b.num_tokens[rev_ids]
|
||||
b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
|
||||
b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
|
||||
b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
|
||||
|
||||
b.block_table.swap_row(id_1, id_2)
|
||||
|
||||
b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
|
||||
b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
|
||||
b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
|
||||
b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
|
||||
b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
|
||||
b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
|
||||
|
||||
b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
|
||||
id_1]
|
||||
b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
|
||||
id_2], b.stop_token_ids[id_1]
|
||||
|
||||
gen_1 = b.generators.pop(id_1, None)
|
||||
gen_2 = b.generators.pop(id_2, None)
|
||||
if gen_1 is not None:
|
||||
b.generators[id_2] = gen_1
|
||||
if gen_2 is not None:
|
||||
b.generators[id_1] = gen_2
|
||||
|
||||
|
||||
def ensure_decodes_first(b: InputBatch):
|
||||
num_reqs = b.num_reqs
|
||||
while True:
|
||||
# Find the first prompt index
|
||||
first_prompt_index = None
|
||||
for i in range(num_reqs):
|
||||
if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
|
||||
first_prompt_index = i
|
||||
break
|
||||
if first_prompt_index is None:
|
||||
break
|
||||
|
||||
# Find the last decode index
|
||||
last_decode_index = None
|
||||
for i in reversed(range(num_reqs)):
|
||||
if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
|
||||
last_decode_index = i
|
||||
break
|
||||
if last_decode_index is None:
|
||||
break
|
||||
|
||||
# Sanity
|
||||
assert first_prompt_index != last_decode_index
|
||||
|
||||
# Check if done
|
||||
if first_prompt_index > last_decode_index:
|
||||
break
|
||||
|
||||
# Swap
|
||||
swap_positions(b, first_prompt_index, last_decode_index)
|
||||
|
||||
@ -5,32 +5,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv, is_pin_memory_available)
|
||||
from vllm.utils import DeviceMemoryProfiler, cdiv
|
||||
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
||||
FlashAttentionMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
@ -38,87 +29,17 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner:
|
||||
class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
model_config = self.model_config
|
||||
cache_config = self.cache_config
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
if cache_config.cache_dtype == "auto":
|
||||
self.kv_cache_dtype = self.dtype
|
||||
else:
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
# Model-related.
|
||||
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
self.num_query_heads = model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = INPUT_REGISTRY
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
|
||||
# NOTE: Initialized input mapper is only used for processing dummy
|
||||
# multimodal data into multimodal kwargs for GPU memory profiling.
|
||||
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
|
||||
self.mm_input_mapper_profiling.use_cache = False
|
||||
|
||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||
model_config=model_config,
|
||||
scheduler_config=scheduler_config,
|
||||
)
|
||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
# Lazy initialization
|
||||
# self.model: nn.Module # Set after load_model
|
||||
# KV caches for forward pass
|
||||
self.kv_caches: List[torch.Tensor] = []
|
||||
# req_id -> (input_id -> encoder_output)
|
||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||
|
||||
# Request states.
|
||||
self.requests: Dict[str, CachedRequestState] = {}
|
||||
# Persistent batch.
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=model_config.get_vocab_size(),
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
@ -171,8 +92,7 @@ class GPUModelRunner:
|
||||
|
||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
||||
self.max_model_len,
|
||||
self.max_num_tokens),
|
||||
self.max_model_len),
|
||||
dtype=np.int32)
|
||||
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
||||
# a faster version of creating a new tensor every time. Thus, we should
|
||||
@ -203,132 +123,6 @@ class GPUModelRunner:
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove stopped requests from the cached states.
|
||||
# Keep the states of the pre-empted requests.
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.requests.pop(req_id, None)
|
||||
self.encoder_cache.pop(req_id, None)
|
||||
|
||||
# Free the cached encoder outputs.
|
||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||
encoder_outputs = self.encoder_cache.get(req_id)
|
||||
if encoder_outputs is not None:
|
||||
encoder_outputs.pop(input_id, None)
|
||||
if not encoder_outputs:
|
||||
self.encoder_cache.pop(req_id, None)
|
||||
|
||||
# Remove the requests from the persistent batch.
|
||||
stopped_req_ids = set().union(
|
||||
scheduler_output.preempted_req_ids,
|
||||
scheduler_output.finished_req_ids,
|
||||
)
|
||||
removed_req_indices: List[int] = []
|
||||
for req_id in stopped_req_ids:
|
||||
req_index = self.input_batch.remove_request(req_id)
|
||||
if req_index is not None:
|
||||
removed_req_indices.append(req_index)
|
||||
|
||||
# Update the states of the running requests.
|
||||
for req_data in scheduler_output.scheduled_running_reqs:
|
||||
req_id = req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
# Update the num_computed_tokens.
|
||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
req_data.num_computed_tokens)
|
||||
|
||||
# Update the block table.
|
||||
num_new_blocks = len(req_data.new_block_ids)
|
||||
if num_new_blocks == 0:
|
||||
continue
|
||||
start_index = len(req_state.block_ids)
|
||||
req_state.block_ids.extend(req_data.new_block_ids)
|
||||
self.input_batch.block_table.append_row(req_index, start_index,
|
||||
req_data.new_block_ids)
|
||||
|
||||
req_ids_to_add: List[str] = []
|
||||
# Add new requests to the cached states.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
sampling_params = new_req_data.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt=new_req_data.prompt,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.model_config.uses_mrope:
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
for mm_input in self.requests[req_id].mm_inputs:
|
||||
if mm_input.get("image_grid_thw") is not None:
|
||||
image_grid_thw.extend(
|
||||
mm_input["image_grid_thw"].tolist())
|
||||
if mm_input.get("video_grid_thw") is not None:
|
||||
video_grid_thw.extend(
|
||||
mm_input["video_grid_thw"].tolist())
|
||||
|
||||
hf_config = self.model_config.hf_config
|
||||
|
||||
self.requests[req_id].mrope_positions, \
|
||||
self.requests[req_id].mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
self.requests[req_id].prompt_token_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.
|
||||
spatial_merge_size,
|
||||
)
|
||||
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Update the cached states of the resumed requests.
|
||||
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
||||
req_id = res_req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
req_state.block_ids = res_req_data.block_ids
|
||||
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
||||
for req_id in req_ids_to_add:
|
||||
req_state = self.requests[req_id]
|
||||
if removed_req_indices:
|
||||
# Fill the empty index.
|
||||
req_index = removed_req_indices.pop()
|
||||
else:
|
||||
# Append to the end.
|
||||
req_index = None
|
||||
self.input_batch.add_request(req_state, req_index)
|
||||
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
@ -359,15 +153,8 @@ class GPUModelRunner:
|
||||
|
||||
# Get batched arange.
|
||||
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
# Equivalent to but faster than:
|
||||
# np.concatenate([np.arange(n) for n in num_scheduled_tokens])
|
||||
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||
arange = np.concatenate(
|
||||
[self.arange_np[:n] for n in num_scheduled_tokens])
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
@ -414,7 +201,8 @@ class GPUModelRunner:
|
||||
|
||||
# Prepare the attention metadata.
|
||||
self.query_start_loc_np[0] = 0
|
||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
np.cumsum(num_scheduled_tokens,
|
||||
out=self.query_start_loc_np[1:num_reqs + 1])
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
@ -618,6 +406,8 @@ class GPUModelRunner:
|
||||
return sampling_metadata
|
||||
|
||||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
assert self.model is not None
|
||||
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
return
|
||||
@ -705,15 +495,14 @@ class GPUModelRunner:
|
||||
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||
return encoder_outputs
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
self._update_states(scheduler_output)
|
||||
assert self.model is not None
|
||||
|
||||
self.update_states(scheduler_output)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
@ -782,10 +571,10 @@ class GPUModelRunner:
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
|
||||
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
|
||||
assert req_id is not None
|
||||
req_state = self.requests[req_id]
|
||||
@ -794,10 +583,10 @@ class GPUModelRunner:
|
||||
assert seq_len <= req_state.num_tokens
|
||||
if seq_len == req_state.num_tokens:
|
||||
# Append the sampled token to the output token ids.
|
||||
token_id = sampled_token_ids[i]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
self.input_batch.num_tokens[i] += 1
|
||||
# OPTIMIZATION: Priming the state updates for later updates.
|
||||
req_state.output_token_ids.append(0)
|
||||
request_seq_lens.append((i, req_state, seq_len))
|
||||
req_state.output_token_ids.append(token_id)
|
||||
else:
|
||||
# Ignore the sampled token from the partial request.
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
@ -806,21 +595,6 @@ class GPUModelRunner:
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
|
||||
# num_reqs entries should be non-None
|
||||
assert all(
|
||||
req_id is not None for req_id in
|
||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
|
||||
|
||||
# NOTE: GPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
|
||||
# Update with the actual token ids
|
||||
for i, req_state, seq_len in request_seq_lens:
|
||||
token_id = sampled_token_ids[i]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
req_state.output_token_ids[-1] = token_id
|
||||
|
||||
if sampler_output.logprob_token_ids is None:
|
||||
logprob_token_ids = None
|
||||
else:
|
||||
@ -830,6 +604,12 @@ class GPUModelRunner:
|
||||
else:
|
||||
logprobs = sampler_output.logprobs.cpu()
|
||||
|
||||
# num_reqs entries should be non-None
|
||||
assert all(
|
||||
req_id is not None for req_id in
|
||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
@ -849,14 +629,15 @@ class GPUModelRunner:
|
||||
self.model_memory_usage / float(2**30))
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
def dummy_run(
|
||||
self,
|
||||
kv_caches,
|
||||
num_tokens: int,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
exec_mode: Optional[ExecutionMode] = None,
|
||||
) -> torch.Tensor:
|
||||
model = self.model
|
||||
if kv_caches is None:
|
||||
kv_caches = self.kv_caches
|
||||
assert self.model is not None
|
||||
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
@ -867,7 +648,7 @@ class GPUModelRunner:
|
||||
positions = self.mrope_positions[:, :num_tokens] \
|
||||
if self.model_config.uses_mrope \
|
||||
else self.positions[:num_tokens]
|
||||
hidden_states = model(
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
@ -877,6 +658,7 @@ class GPUModelRunner:
|
||||
return hidden_states
|
||||
|
||||
def profile_run(self) -> None:
|
||||
assert self.model is not None
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value `None`.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
@ -982,7 +764,7 @@ class GPUModelRunner:
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
|
||||
hidden_states = self.dummy_run(dummy_kv_caches, self.max_num_tokens)
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
logits = logits[:self.max_num_tokens]
|
||||
# TODO(woosuk): Consider the memory usage of the sampler.
|
||||
@ -1008,8 +790,8 @@ class GPUModelRunner:
|
||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(num_tokens)
|
||||
self.dummy_run(None, num_tokens)
|
||||
self.dummy_run(None, num_tokens)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
@ -1052,38 +834,3 @@ class GPUModelRunner:
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
|
||||
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
kv_cache_spec: KVCacheSpec = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
# TODO: Support other attention modules, e.g., sliding window,
|
||||
# cross-attention, MLA.
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
@ -15,20 +13,17 @@ from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase, check_if_gpu_supports_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
|
||||
class Worker:
|
||||
class GPUWorker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -38,46 +33,8 @@ class Worker:
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
super().__init__(vllm_config, local_rank, rank,
|
||||
distributed_init_method)
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
@ -97,31 +54,39 @@ class Worker:
|
||||
allocator.wake_up()
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
assert self.device_config.device.type == "cuda"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
# the synchronization point. This causes the memory usage to grow
|
||||
# as the number of all_reduce calls increases. This env var disables
|
||||
# this behavior.
|
||||
# Related issue:
|
||||
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
init_cuda_worker_distributed_environment(self.parallel_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@ -139,6 +104,7 @@ class Worker:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
assert self.model_runner is not None
|
||||
self.model_runner.load_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -160,6 +126,7 @@ class Worker:
|
||||
_, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
assert self.model_runner is not None
|
||||
self.model_runner.profile_run()
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
@ -191,9 +158,6 @@ class Worker:
|
||||
|
||||
return int(available_kv_cache_memory)
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
@ -203,9 +167,12 @@ class Worker:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
assert self.model_runner is not None
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
assert self.model_runner is not None
|
||||
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
@ -217,44 +184,32 @@ class Worker:
|
||||
]
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size)
|
||||
self.model_runner.dummy_run(None, size)
|
||||
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
assert self.model_runner is not None
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
return output if self.rank == 0 else None
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
def init_cuda_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
@ -264,21 +219,22 @@ def init_worker_distributed_environment(
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
# TODO: Remove
|
||||
# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# # Check if the GPU supports the dtype.
|
||||
# if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
# if not current_platform.has_device_capability(80):
|
||||
# capability = current_platform.get_device_capability()
|
||||
# gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
# if capability is None:
|
||||
# compute_str = "does not have a compute capability"
|
||||
# else:
|
||||
# version_str = capability.as_version_str()
|
||||
# compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the"
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
# raise ValueError(
|
||||
# "Bfloat16 is only supported on GPUs with compute capability "
|
||||
# f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
# "You can use float16 instead by explicitly setting the"
|
||||
# "`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
307
vllm/v1/worker/model_runner_base.py
Normal file
307
vllm/v1/worker/model_runner_base.py
Normal file
@ -0,0 +1,307 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ExecutionMode(enum.Enum):
|
||||
PREFILL = enum.auto()
|
||||
DECODE = enum.auto()
|
||||
PREFIX_PREFILL = enum.auto()
|
||||
|
||||
def is_prefill(self) -> bool:
|
||||
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
|
||||
|
||||
|
||||
class ModelRunnerBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.device_config = vllm_config.device_config
|
||||
|
||||
model_config = self.model_config
|
||||
cache_config = self.cache_config
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
# Model-related.
|
||||
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
self.num_query_heads = model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
|
||||
self.model: Optional[nn.Module] = None
|
||||
|
||||
# Persistent batch.
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
)
|
||||
|
||||
# Request states.
|
||||
self.requests: Dict[str, CachedRequestState] = {}
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = INPUT_REGISTRY
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
|
||||
# NOTE: Initialized input mapper is only used for processing dummy
|
||||
# multimodal data into multimodal kwargs for GPU memory profiling.
|
||||
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
|
||||
self.mm_input_mapper_profiling.use_cache = False
|
||||
|
||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
)
|
||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
# req_id -> (input_id -> encoder_output)
|
||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||
|
||||
def update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove stopped requests from the cached states.
|
||||
# Keep the states of the pre-empted requests.
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.requests.pop(req_id, None)
|
||||
self.encoder_cache.pop(req_id, None)
|
||||
|
||||
# Free the cached encoder outputs.
|
||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||
encoder_outputs = self.encoder_cache.get(req_id)
|
||||
if encoder_outputs is not None:
|
||||
encoder_outputs.pop(input_id, None)
|
||||
if not encoder_outputs:
|
||||
self.encoder_cache.pop(req_id, None)
|
||||
|
||||
# Remove the requests from the persistent batch.
|
||||
stopped_req_ids = set().union(
|
||||
scheduler_output.preempted_req_ids,
|
||||
scheduler_output.finished_req_ids,
|
||||
)
|
||||
removed_req_indices: List[int] = []
|
||||
for req_id in stopped_req_ids:
|
||||
req_index = self.input_batch.remove_request(req_id)
|
||||
if req_index is not None:
|
||||
removed_req_indices.append(req_index)
|
||||
|
||||
# Update the states of the running requests.
|
||||
for req_data in scheduler_output.scheduled_running_reqs:
|
||||
req_id = req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
# Update the num_computed_tokens.
|
||||
req_state.num_computed_tokens = req_data.num_computed_tokens
|
||||
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||
req_data.num_computed_tokens)
|
||||
|
||||
# Update the block table.
|
||||
num_new_blocks = len(req_data.new_block_ids)
|
||||
if num_new_blocks == 0:
|
||||
continue
|
||||
start_index = len(req_state.block_ids)
|
||||
req_state.block_ids.extend(req_data.new_block_ids)
|
||||
self.input_batch.block_table.append_row(req_index, start_index,
|
||||
req_data.new_block_ids)
|
||||
|
||||
req_ids_to_add: List[str] = []
|
||||
# Add new requests to the cached states.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
sampling_params = new_req_data.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
generator = None
|
||||
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt=new_req_data.prompt,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.model_config.uses_mrope:
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
for mm_input in self.requests[req_id].mm_inputs:
|
||||
if mm_input.get("image_grid_thw") is not None:
|
||||
image_grid_thw.extend(
|
||||
mm_input["image_grid_thw"].tolist())
|
||||
if mm_input.get("video_grid_thw") is not None:
|
||||
video_grid_thw.extend(
|
||||
mm_input["video_grid_thw"].tolist())
|
||||
|
||||
hf_config = self.model_config.hf_config
|
||||
|
||||
self.requests[req_id].mrope_positions, \
|
||||
self.requests[req_id].mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
self.requests[req_id].prompt_token_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.
|
||||
spatial_merge_size,
|
||||
)
|
||||
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Update the cached states of the resumed requests.
|
||||
for res_req_data in scheduler_output.scheduled_resumed_reqs:
|
||||
req_id = res_req_data.req_id
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
req_state.block_ids = res_req_data.block_ids
|
||||
req_state.num_computed_tokens = res_req_data.num_computed_tokens
|
||||
req_ids_to_add.append(req_id)
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
||||
for req_id in req_ids_to_add:
|
||||
req_state = self.requests[req_id]
|
||||
if removed_req_indices:
|
||||
# Fill the empty index.
|
||||
req_index = removed_req_indices.pop()
|
||||
else:
|
||||
# Append to the end.
|
||||
req_index = None
|
||||
self.input_batch.add_request(req_state, req_index)
|
||||
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
assert self.model is not None
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
|
||||
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
kv_cache_spec: KVCacheSpec = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
# TODO: Support other attention modules, e.g., sliding window,
|
||||
# cross-attention, MLA.
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
raise NotImplementedError()
|
||||
|
||||
def load_model(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def dummy_run(
|
||||
self,
|
||||
kv_caches,
|
||||
num_tokens: int,
|
||||
seq_len: Optional[int] = None,
|
||||
exec_mode: Optional[ExecutionMode] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def profile_run(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def capture_model(self) -> None:
|
||||
raise NotImplementedError()
|
||||
888
vllm/v1/worker/tpu_model_runner.py
Normal file
888
vllm/v1/worker/tpu_model_runner.py
Normal file
@ -0,0 +1,888 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
# TPU XLA related
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
PallasMetadata)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
|
||||
ensure_decodes_first)
|
||||
from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Here we utilize the behavior that out-of-bound index is ignored.
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptDecodeInfo:
|
||||
prompt_req_ids: List[str]
|
||||
decode_req_ids: List[str]
|
||||
prompt_scheduled_tokens: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptData:
|
||||
input_tokens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
attn_metadata: PallasMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeData:
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
attn_metadata: Optional[PallasMetadata] = None
|
||||
|
||||
|
||||
class TPUModelRunner(ModelRunnerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
# KV caches for forward pass
|
||||
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
# Cached torch/numpy tensors
|
||||
self.num_swaps = 2
|
||||
self.cur_swap_id = 0
|
||||
self.input_ids_cpu = []
|
||||
self.input_ids_np = []
|
||||
self.input_positions_cpu = []
|
||||
self.input_positions_np = []
|
||||
self.slot_mapping_cpu = []
|
||||
self.slot_mapping_np = []
|
||||
self.prompt_context_lens_cpu = []
|
||||
self.prompt_effective_query_lens_cpu = []
|
||||
self.decode_context_lens_cpu = []
|
||||
self.decode_context_lens_np = []
|
||||
for _ in range(self.num_swaps):
|
||||
self.input_ids_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
|
||||
|
||||
self.input_positions_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_positions_np.append(
|
||||
self.input_positions_cpu[-1].numpy())
|
||||
|
||||
self.slot_mapping_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu"))
|
||||
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
|
||||
|
||||
self.prompt_context_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
self.prompt_effective_query_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
|
||||
self.decode_context_lens_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.decode_context_lens_np.append(
|
||||
self.decode_context_lens_cpu[-1].numpy())
|
||||
|
||||
# Range tensor with values [0 .. self.max_num_tokens - 1].
|
||||
# Used to initialize positions / context_lens / seq_lens
|
||||
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
|
||||
|
||||
def swap_step(self):
|
||||
self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
|
||||
|
||||
def _get_prompts_and_decodes(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> PromptDecodeInfo:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# Traverse decodes first
|
||||
decode_req_ids = []
|
||||
for i in range(num_reqs):
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
|
||||
if num_computed_tokens < num_prompt_tokens:
|
||||
# This is prompt
|
||||
break
|
||||
|
||||
# This is decode
|
||||
assert num_scheduled_tokens == 1
|
||||
decode_req_ids.append(req_id)
|
||||
|
||||
# Traverse prompts
|
||||
prompt_req_ids = []
|
||||
prompt_scheduled_tokens = []
|
||||
for i in range(len(decode_req_ids), num_reqs):
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
|
||||
# Must be prompt
|
||||
assert num_computed_tokens < num_prompt_tokens
|
||||
|
||||
prompt_req_ids.append(req_id)
|
||||
prompt_scheduled_tokens.append(num_scheduled_tokens)
|
||||
|
||||
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
|
||||
prompt_scheduled_tokens)
|
||||
|
||||
def _prepare_prompt(self, req_index: int,
|
||||
num_scheduled_tokens: int) -> PromptData:
|
||||
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
|
||||
req_index]
|
||||
num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
|
||||
|
||||
# Must be prompt
|
||||
assert num_computed_tokens < num_prompt_tokens
|
||||
|
||||
# Prompt len
|
||||
prompt_len = num_scheduled_tokens
|
||||
padded_prompt_len = _get_padded_prompt_len(prompt_len)
|
||||
assert padded_prompt_len <= self.max_model_len
|
||||
|
||||
# Seq len
|
||||
seq_len = num_computed_tokens + prompt_len
|
||||
padded_seq_len = num_computed_tokens + padded_prompt_len
|
||||
|
||||
# DEBUG
|
||||
# print("_prepare_prompt:")
|
||||
# print(" prompt_len = {}".format(prompt_len))
|
||||
# print(" padded_prompt_len = {}".format(padded_prompt_len))
|
||||
# print(" num_computed_tokens = {}".format(num_computed_tokens))
|
||||
# print(" num_prompt_tokens = {}".format(num_prompt_tokens))
|
||||
# print(" seq_len = {}".format(seq_len))
|
||||
# print(" padded_seq_len = {}".format(padded_seq_len))
|
||||
|
||||
# Input tokens
|
||||
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
|
||||
req_index, num_computed_tokens:padded_seq_len]
|
||||
input_tokens_cpu[prompt_len:] = 0
|
||||
|
||||
# DEBUG
|
||||
# print(" input_tokens_cpu.shape = {} val = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(num_computed_tokens,
|
||||
self.arange_np[:padded_prompt_len],
|
||||
out=input_positions_np)
|
||||
input_positions_np[prompt_len:] = 0
|
||||
|
||||
# DEBUG
|
||||
# print(" input_positions_np.shape = {} val = {}".format(
|
||||
# input_positions_np.shape, input_positions_np))
|
||||
|
||||
# Slot mapping
|
||||
block_table_np = \
|
||||
self.input_batch.block_table.get_numpy_array()
|
||||
block_numbers_np = block_table_np[req_index, input_positions_np //
|
||||
self.block_size]
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
|
||||
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} val = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
# Block table
|
||||
block_table_cpu = None
|
||||
if num_computed_tokens > 0:
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
block_table_cpu = block_table_cpu[req_index]
|
||||
|
||||
# DEBUG
|
||||
# print(" block_table_cpu = {}".format(block_table_cpu))
|
||||
|
||||
# Context len
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
|
||||
if num_computed_tokens > 0:
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
|
||||
|
||||
# Effective query len
|
||||
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
|
||||
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
|
||||
input_positions = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
block_table = block_table_cpu.reshape(1, -1).to(
|
||||
self.device) if block_table_cpu is not None else None
|
||||
|
||||
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
|
||||
self.device)
|
||||
effective_query_lens = self.prompt_effective_query_lens_cpu[
|
||||
self.cur_swap_id].to(self.device)
|
||||
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" input_tokens.shape = {} val = {}".format(
|
||||
# input_tokens.shape, input_tokens))
|
||||
# print(" input_positions.shape = {} val = {}".format(
|
||||
# input_positions.shape, input_positions))
|
||||
# print(" slot_mapping.shape = {} val = {}".format(
|
||||
# slot_mapping.shape, slot_mapping))
|
||||
# print(" block_table = {}".format(block_table))
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
# print(" effective_query_lens.shape = {} val = {}".format(
|
||||
# effective_query_lens.shape, effective_query_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=1,
|
||||
num_prefill_tokens=0, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_table,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=effective_query_lens,
|
||||
)
|
||||
|
||||
return PromptData(input_tokens, input_positions, attn_metadata)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
decode_req_ids: List[str],
|
||||
) -> DecodeData:
|
||||
# Batch size
|
||||
batch_size = len(decode_req_ids)
|
||||
padded_batch_size = _get_padded_batch_size(batch_size)
|
||||
assert padded_batch_size <= self.max_model_len
|
||||
|
||||
# Init [0 .. batch_size - 1]
|
||||
req_indices_np = self.arange_np[:padded_batch_size]
|
||||
|
||||
# DEBUG
|
||||
# print("_prepare_decode:")
|
||||
# print(" batch_size = {}".format(batch_size))
|
||||
# print(" padded_batch_size = {}".format(padded_batch_size))
|
||||
# print(" req_indices_np.shape = {} val = {}".format(
|
||||
# req_indices_np.shape, req_indices_np))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
0,
|
||||
out=input_positions_np)
|
||||
input_positions_np[batch_size:] = 0
|
||||
input_positions_cpu = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
|
||||
# DEBUG
|
||||
# print(" input_positions_cpu.shape = {} data = {}".format(
|
||||
# input_positions_cpu.shape, input_positions_cpu))
|
||||
|
||||
# Input tokens
|
||||
token_indices_np = (
|
||||
input_positions_np +
|
||||
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
|
||||
input_tokens_cpu = self.input_ids_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices_np),
|
||||
out=input_tokens_cpu)
|
||||
input_tokens_cpu[batch_size:] = 0
|
||||
|
||||
# DEBUG
|
||||
# print(" token_indices_np.shape = {} val = {}".format(
|
||||
# token_indices_np.shape, token_indices_np))
|
||||
# print(" input_tokens_cpu.shape = {} data = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Slot mapping
|
||||
block_table_indices_np = (
|
||||
req_indices_np * self.max_num_blocks_per_req +
|
||||
input_positions_np // self.block_size)
|
||||
|
||||
# DEBUG
|
||||
# print(
|
||||
# " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
|
||||
# .format(block_table_indices_np.shape, block_table_indices_np,
|
||||
# self.max_num_blocks_per_req))
|
||||
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
|
||||
# DEBUG
|
||||
# print(" block_table_cpu.shape = {} data = {}".format(
|
||||
# block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
|
||||
|
||||
block_numbers_np = block_table_cpu.flatten(
|
||||
)[block_table_indices_np].numpy()
|
||||
|
||||
# DEBUG
|
||||
# print(" block_numbers_np.shape = {} data = {}".format(
|
||||
# block_numbers_np.shape, block_numbers_np))
|
||||
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
# DEBUG
|
||||
# print(" block_offsets_np.shape = {} data = {}".format(
|
||||
# block_offsets_np.shape, block_offsets_np))
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
|
||||
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} data = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
block_table_cpu = block_table_cpu[:padded_batch_size]
|
||||
|
||||
# Context lens
|
||||
context_lens_np = self.decode_context_lens_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
1,
|
||||
out=context_lens_np)
|
||||
context_lens_np[batch_size:] = 0
|
||||
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
|
||||
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].reshape(-1,
|
||||
1).to(self.device)
|
||||
block_table = block_table_cpu.to(self.device)
|
||||
context_lens = self.decode_context_lens_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].to(self.device)
|
||||
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=padded_batch_size,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_table,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=None,
|
||||
)
|
||||
|
||||
return DecodeData(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
# Update cached state
|
||||
self.update_states(scheduler_output)
|
||||
|
||||
# If necessary, swap decodes/prompts to have all decodes on the start
|
||||
ensure_decodes_first(self.input_batch)
|
||||
|
||||
# Prepare prompts/decodes info
|
||||
pd_info = self._get_prompts_and_decodes(scheduler_output)
|
||||
|
||||
# Init
|
||||
num_prompts = len(pd_info.prompt_req_ids)
|
||||
num_decodes = len(pd_info.decode_req_ids)
|
||||
decode_data = None
|
||||
sampled_token_ids = [0] * self.input_batch.num_reqs
|
||||
|
||||
# Run each prompt individually
|
||||
is_first = True
|
||||
for i in range(num_prompts):
|
||||
req_id = pd_info.prompt_req_ids[i]
|
||||
req_index = num_decodes + i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
|
||||
prompt_len = num_scheduled_tokens
|
||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
|
||||
|
||||
# Prepare first prompt
|
||||
if is_first:
|
||||
prompt_data = self._prepare_prompt(req_index,
|
||||
num_scheduled_tokens)
|
||||
is_first = False
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(prompt_data.attn_metadata,
|
||||
self.vllm_config):
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(prompt_data.input_tokens,
|
||||
prompt_data.input_positions,
|
||||
prompt_data.attn_metadata,
|
||||
self.kv_caches)
|
||||
|
||||
# In parallel to TPU execution, prepare the next iteration
|
||||
if i < num_prompts - 1:
|
||||
# There is next prompt => prepare it
|
||||
prompt_data = self._prepare_prompt(
|
||||
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
|
||||
elif i == num_prompts - 1 and num_decodes > 0:
|
||||
# There is next decode => prepare it
|
||||
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||
|
||||
# Update cached state (if prompt is fully done)
|
||||
if seq_len >= len(req_state.prompt_token_ids):
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
selected_token_ids_cpu = selected_token_ids.cpu()
|
||||
|
||||
# Get output token
|
||||
token_id = selected_token_ids_cpu[prompt_len - 1].item()
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
# DEBUG
|
||||
# print(
|
||||
# " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
|
||||
# .format(token_id, prompt_len, req_id, req_index,
|
||||
# selected_token_ids_cpu))
|
||||
|
||||
# Add output token to the request
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Run decodes (a single batch)
|
||||
if num_decodes > 0:
|
||||
|
||||
# Prepare decode (if was not yet prepared)
|
||||
if decode_data is None:
|
||||
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(decode_data.attn_metadata,
|
||||
self.vllm_config):
|
||||
assert self.model is not None
|
||||
selected_token_ids = self.model(decode_data.input_tokens,
|
||||
decode_data.input_positions,
|
||||
decode_data.attn_metadata,
|
||||
self.kv_caches)
|
||||
|
||||
# Transfer sampled tokens from TPU to CPU
|
||||
decode_token_ids_cpu = selected_token_ids.cpu()
|
||||
# Convert to list
|
||||
decode_token_ids_list = decode_token_ids_cpu.tolist()
|
||||
|
||||
# Update cached state for each decode request
|
||||
for i in range(num_decodes):
|
||||
req_id = pd_info.decode_req_ids[i]
|
||||
req_index = i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = req_state.num_computed_tokens + 1
|
||||
|
||||
token_id = decode_token_ids_list[i]
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Create output
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprob_token_ids_cpu=None,
|
||||
logprobs_cpu=None,
|
||||
)
|
||||
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.device = self.device_config.device
|
||||
|
||||
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
|
||||
# process, the ranks can be different from the ranks internally assigned
|
||||
# by the xm runtime. Therefore, there is a mismatch in the rank
|
||||
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
|
||||
# This is not a problem in linear layers because all-reduce is
|
||||
# rank-agnostic. However, it matters for all-gather as the ranks
|
||||
# determine the order of concatenating the output tensors.
|
||||
# As a workaround, we use the xm's rank assignment only when loading
|
||||
# the embedding weights.
|
||||
xm_tp_rank = xr.global_ordinal()
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=xm_tp_rank):
|
||||
model = get_model(vllm_config=self.vllm_config)
|
||||
model = model.eval()
|
||||
xm.wait_device_ops()
|
||||
model = ModelWrapperV1(model)
|
||||
self.model = torch.compile(model,
|
||||
backend="openxla",
|
||||
fullgraph=True,
|
||||
dynamic=False)
|
||||
|
||||
def dummy_run(
|
||||
self,
|
||||
kv_caches,
|
||||
num_tokens: int,
|
||||
seq_len: Optional[int] = None,
|
||||
exec_mode: Optional[ExecutionMode] = None,
|
||||
) -> None:
|
||||
assert seq_len is not None
|
||||
assert exec_mode is not None
|
||||
|
||||
exec_mode = ExecutionMode(exec_mode)
|
||||
if exec_mode.is_prefill():
|
||||
seq_len = (seq_len + 15) // 16 * 16
|
||||
token_ids = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
if exec_mode == ExecutionMode.PREFILL:
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=num_tokens,
|
||||
num_prefill_tokens=num_tokens * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=None,
|
||||
context_lens=None,
|
||||
effective_query_lens=None,
|
||||
)
|
||||
|
||||
else:
|
||||
context_lens = torch.ones((num_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
block_tables = torch.zeros(
|
||||
(num_tokens, self.max_num_blocks_per_req),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
effective_query_lens = torch.ones_like(context_lens)
|
||||
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=num_tokens,
|
||||
num_prefill_tokens=num_tokens * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=effective_query_lens,
|
||||
)
|
||||
else:
|
||||
assert seq_len == 1
|
||||
token_ids = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((num_tokens, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
block_tables = torch.zeros(
|
||||
(num_tokens, self.max_num_blocks_per_req),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
context_lens = torch.ones((num_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=num_tokens * seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
||||
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
||||
# overhead by reusing the FX graph for different shapes.
|
||||
# However, the XLA graph will still require static shapes and needs to
|
||||
# be re-compiled for every different shapes. This overhead is inevitable
|
||||
# in the first run, but can be skipped afterwards as we cache the XLA
|
||||
# graphs in the disk (VLLM_XLA_CACHE_PATH).
|
||||
if exec_mode.is_prefill():
|
||||
# Prefll
|
||||
torch._dynamo.mark_dynamic(token_ids, 1)
|
||||
torch._dynamo.mark_dynamic(position_ids, 1)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
||||
else:
|
||||
# Decode
|
||||
torch._dynamo.mark_dynamic(token_ids, 0)
|
||||
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
self.model(token_ids, position_ids, attn_metadata, kv_caches)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
|
||||
# Prefill
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info(" -- Compilation for prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Prefix prefill
|
||||
if self.scheduler_config.enable_chunked_prefill:
|
||||
logger.info("Compiling the model with different input shapes for "
|
||||
"prefix prefill:")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if (num_tokens
|
||||
>= self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
end = time.time()
|
||||
logger.info(
|
||||
" -- Compilation for prefix prefill done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
# Decode
|
||||
logger.info(
|
||||
"Compiling the model with different input shapes for decode:")
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||
while True:
|
||||
self.dummy_run(self.kv_caches,
|
||||
batch_size,
|
||||
seq_len,
|
||||
exec_mode=ExecutionMode.DECODE)
|
||||
xm.wait_device_ops()
|
||||
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info(" -- Compilation for decode done in %.2f [secs].",
|
||||
end - start)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
if len(kv_cache_config.groups) > 1:
|
||||
raise NotImplementedError(
|
||||
"Hybrid models with more than one KV cache type are not "
|
||||
"supported yet.")
|
||||
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
|
||||
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
|
||||
layer_spec.head_size)
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
|
||||
class ModelWrapperV1(nn.Module):
|
||||
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
Args:
|
||||
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||
attn_metadata: The Pallas attention metadata.
|
||||
input_lens: The actual input lengths of shape [batch_size].
|
||||
t: The sampling temperature of shape [batch_size].
|
||||
p: The top-p probability of shape [batch_size].
|
||||
num_samples: Number of samples to draw from each logits vector.
|
||||
kv_caches: The key and value caches. They can be None during the
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
# Skip this in memory profiling at initialization.
|
||||
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
# work, we need to flatten the first three dimensions and modify
|
||||
# the slot_mapping accordingly.
|
||||
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
head_indicies = torch.arange(0,
|
||||
num_kv_heads,
|
||||
device=slot_mapping.device,
|
||||
dtype=slot_mapping.dtype)
|
||||
head_indicies *= block_size * num_blocks
|
||||
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
|
||||
-1, num_kv_heads)
|
||||
slot_mapping = slot_mapping + head_indicies.view(1, -1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
token_ids,
|
||||
position_ids,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
||||
# Greedy sampling.
|
||||
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
|
||||
return argmax_token_ids
|
||||
|
||||
|
||||
def _get_padded_prompt_len(x: int) -> int:
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||
# multiple of 16. This is also good for performance.
|
||||
if x <= 16:
|
||||
return 16
|
||||
return 1 << (x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_padded_batch_size(batch_size: int) -> int:
|
||||
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||
# To meet this requirement in the simplest way, we set the minimal batch
|
||||
# size to 8.
|
||||
if batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 15) // 16) * 16
|
||||
153
vllm/v1/worker/tpu_worker.py
Normal file
153
vllm/v1/worker/tpu_worker.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""A TPU worker class."""
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(vllm_config, local_rank, rank,
|
||||
distributed_init_method)
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
init_tpu_worker_distributed_environment(self.parallel_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
|
||||
# Device initialization should happen after initializing
|
||||
# the distributed runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||
f"tp{world_size}_rank{rank}")
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
self.model_runner = TPUModelRunner(self.vllm_config, self.device)
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
assert self.model_runner is not None
|
||||
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
dtype = layer_spec.dtype
|
||||
|
||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
||||
|
||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
runner_kv_caches = []
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
runner_kv_caches)
|
||||
|
||||
self.model_runner.dummy_run(
|
||||
runner_kv_caches,
|
||||
num_tokens=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
exec_mode=ExecutionMode.PREFILL,
|
||||
)
|
||||
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
|
||||
return int(tpu_kv_cache_bytes)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
assert self.model_runner is not None
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
return output if self.rank == 0 else None
|
||||
|
||||
|
||||
def init_tpu_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
173
vllm/v1/worker/worker_base.py
Normal file
173
vllm/v1/worker/worker_base.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""A GPU worker class."""
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.model_runner_base import ModelRunnerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
|
||||
|
||||
class WorkerBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
# Initialized by the specific platform
|
||||
self.model_runner: Optional[ModelRunnerBase] = None
|
||||
|
||||
def load_model(self) -> None:
|
||||
assert self.model_runner is not None
|
||||
self.model_runner.load_model()
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
assert self.model_runner is not None
|
||||
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
assert self.model_runner is not None
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
assert self.model_runner is not None
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
assert self.model_runner is not None
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def init_device(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[ModelRunnerOutput]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not current_platform.has_device_capability(80):
|
||||
capability = current_platform.get_device_capability()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the"
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
|
||||
def get_cache_block_size(
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||
parallel_config, LayerBlockType.attention)
|
||||
|
||||
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_attention_layers * (key_cache_block + value_cache_block)
|
||||
if cache_config.cache_dtype == "auto":
|
||||
dtype = model_config.dtype
|
||||
else:
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
return dtype_size * total
|
||||
@ -455,6 +455,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
|
||||
is not None)
|
||||
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
|
||||
self.decode_only = True
|
||||
|
||||
# Attention metadata inputs.
|
||||
if self.attn_backend is not None:
|
||||
@ -476,10 +477,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
self.finished_requests_ids = finished_requests_ids
|
||||
|
||||
# if the current batch is decode-only.
|
||||
# will be set to False if there is any non-decode request.
|
||||
self.decode_only = True
|
||||
|
||||
# Intermediate data (data in CPU before going to GPU) for
|
||||
# the current sequence group.
|
||||
self.inter_data_list: List[
|
||||
|
||||
Reference in New Issue
Block a user