[Docs] Convert rST to MyST (Markdown) (#11145)
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
(arch-overview)=

# Architecture Overview

This document provides an overview of the vLLM architecture.

```{contents} Table of Contents
:depth: 2
:local: true
```

## Entrypoints

vLLM provides a number of entrypoints for interacting with the system. The
following diagram shows the relationship between them.

```{image} /assets/design/arch_overview/entrypoints.excalidraw.png
:alt: Entrypoints Diagram
```

### LLM Class

The LLM class provides the primary Python interface for doing offline inference,
which is interacting with a model without using a separate model inference
server.

Here is a sample of `LLM` class usage:

```python
from vllm import LLM, SamplingParams

# Define a list of input prompts
prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The largest ocean is",
]

# Define sampling parameters
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Initialize the LLM engine with the OPT-125M model
llm = LLM(model="facebook/opt-125m")

# Generate outputs for the input prompts
outputs = llm.generate(prompts, sampling_params)

# Print the generated outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

More API details can be found in the {doc}`Offline Inference
</dev/offline_inference/offline_index>` section of the API docs.

The code for the `LLM` class can be found in [vllm/entrypoints/llm.py](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py).

### OpenAI-compatible API server

The second primary interface to vLLM is via its OpenAI-compatible API server.
This server can be started using the `vllm serve` command.

```bash
vllm serve <model>
```

The code for the `vllm` CLI can be found in [vllm/scripts.py](https://github.com/vllm-project/vllm/blob/main/vllm/scripts.py).

Sometimes you may see the API server entrypoint used directly instead of via the
`vllm` CLI command. For example:

```bash
python -m vllm.entrypoints.openai.api_server --model <model>
```

That code can be found in [vllm/entrypoints/openai/api_server.py](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py).

More details on the API server can be found in the {doc}`OpenAI Compatible
Server </serving/openai_compatible_server>` document.

## LLM Engine

The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of
the vLLM system, handling model inference and asynchronous request processing.

```{image} /assets/design/arch_overview/llm_engine.excalidraw.png
:alt: LLMEngine Diagram
```

### LLMEngine

The `LLMEngine` class is the core component of the vLLM engine. It is
responsible for receiving requests from clients and generating outputs from the
model. The `LLMEngine` includes input processing, model execution (possibly
distributed across multiple hosts and/or GPUs), scheduling, and output
processing.

- **Input Processing**: Handles tokenization of input text using the specified
  tokenizer.
- **Scheduling**: Chooses which requests are processed in each step.
- **Model Execution**: Manages the execution of the language model, including
  distributed execution across multiple GPUs.
- **Output Processing**: Processes the outputs generated by the model, decoding the
  token IDs from a language model into human-readable text.

The code for `LLMEngine` can be found in [vllm/engine/llm_engine.py].

### AsyncLLMEngine

The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class.
It uses `asyncio` to create a background loop that continuously processes
incoming requests. The `AsyncLLMEngine` is designed for online serving, where it
can handle multiple concurrent requests and stream outputs to clients.

The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo
API server that serves as a simpler example in
[vllm/entrypoints/api_server.py].

The code for `AsyncLLMEngine` can be found in [vllm/engine/async_llm_engine.py].

## Worker

A worker is a process that runs the model inference. vLLM follows the common
practice of using one process to control one accelerator device, such as a GPU.
For example, if we use tensor parallelism of size 2 and pipeline parallelism of
size 2, we will have 4 workers in total. Workers are identified by their
`rank` and `local_rank`. `rank` is used for global orchestration, while
`local_rank` is mainly used for assigning the accelerator device and accessing
local resources such as the file system and shared memory.
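For the tensor parallel 2 x pipeline parallel 2 example above, here is a minimal sketch of how global ranks can map to parallel coordinates and devices. The grouping convention and the two-GPUs-per-node assumption are illustrative, not vLLM's actual orchestration code:

```python
# Illustrative rank bookkeeping for TP=2 x PP=2 (4 workers) on 2 nodes x 2 GPUs.
# The grouping convention here is an assumption, not vLLM's actual code.
TP_SIZE, PP_SIZE, GPUS_PER_NODE = 2, 2, 2

for rank in range(TP_SIZE * PP_SIZE):
    pp_rank, tp_rank = divmod(rank, TP_SIZE)  # rank drives global orchestration
    local_rank = rank % GPUS_PER_NODE         # local_rank picks the device on its node
    print(f"rank={rank}: pp={pp_rank}, tp={tp_rank}, device=cuda:{local_rank}")
```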
## Model Runner

Every worker has one model runner object, responsible for loading and running
the model. Much of the model execution logic resides here, such as preparing
input tensors and capturing CUDA graphs.

## Model

Every model runner object has one model object, which is the actual
`torch.nn.Module` instance. See [huggingface_integration](#huggingface-integration) for how various
configurations affect the class we ultimately get.

## Class Hierarchy

The following figure shows the class hierarchy of vLLM:

> ```{figure} /assets/design/hierarchy.png
> :align: center
> :alt: query
> :width: 100%
> ```

There are several important design choices behind this class hierarchy:

1\. **Extensibility**: All classes in the hierarchy accept a configuration object
containing all the necessary information. The [VllmConfig](https://github.com/vllm-project/vllm/blob/d1c6799b8870e513bf4f2305cbf6cda9fc3d773b/vllm/config.py#L2036)
class is the main configuration object that is passed around. The class
hierarchy is quite deep, and every class needs to read the configuration it is
interested in. By encapsulating all configurations in one object, we can easily
pass the config object around and read the configuration we need. Suppose we
want to add a new feature that only touches the model runner: we only need to
add a new configuration option to
the `VllmConfig` class, and the model runner can access it directly. We don't
need to change the constructor of the engine, worker, or model class to pass the
new configuration option.
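As a minimal sketch of this pattern (the class and field names below are illustrative stand-ins, not vLLM's actual definitions), every layer of the hierarchy receives the same config object and reads only the part it cares about:

```python
from dataclasses import dataclass, field

# Illustrative stand-ins for vLLM's config classes, not the real definitions.
@dataclass
class ModelRunnerConfig:
    new_feature_flag: bool = False  # the newly added option

@dataclass
class VllmConfigSketch:
    model: str = "facebook/opt-125m"
    runner_config: ModelRunnerConfig = field(default_factory=ModelRunnerConfig)

class ModelRunnerSketch:
    def __init__(self, vllm_config: VllmConfigSketch):
        # The model runner reads only the option it cares about; no other
        # constructor in the hierarchy had to change to deliver it.
        self.use_new_feature = vllm_config.runner_config.new_feature_flag

class WorkerSketch:
    def __init__(self, vllm_config: VllmConfigSketch):
        # The worker just forwards the whole config object downward.
        self.model_runner = ModelRunnerSketch(vllm_config)
```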
2\. **Uniformity**: The model runner needs a unified interface to create and
initialize the model. vLLM supports more than 50 types of popular open-source
models. Each model has its own initialization logic. If the constructor
signature varies with models, the model runner does not know how to call the
constructor accordingly without complicated and error-prone inspection logic. A
uniform constructor also makes it easy to compose models: a vision-language
model often consists
of a vision model and a language model. By making the constructor uniform, we
can easily create a vision model and a language model and compose them into a
vision-language model.

````{note}
To support this change, all vLLM models' signatures have been updated to:

```python
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
```

To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one:

```python
class MyOldModel(nn.Module):
    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        ...

from vllm.config import VllmConfig
class MyNewModel(MyOldModel):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        super().__init__(config, cache_config, quant_config, lora_config, prefix)

if __version__ >= "0.6.4":
    MyModel = MyNewModel
else:
    MyModel = MyOldModel
```

This way, the model can work with both old and new versions of vLLM.
````

3\. **Sharding and Quantization at Initialization**: Certain features require
changing the model weights. For example, tensor parallelism needs to shard the
model weights, and quantization needs to quantize the model weights. There are
two possible ways to implement this feature. One way is to change the model
weights after the model is initialized; the other is to change them during
initialization. vLLM chooses the latter, because changing the weights after
initialization does not scale to large models. Suppose we want to run a 405B
model (with roughly 810GB of weights) on 16 H100 80GB GPUs; ideally, every GPU
should only load about 50GB of weights. If the weights are changed only after
the model is fully
initialized, we need to load the full 810GB weights to every GPU and then shard
the weights, leading to a huge memory overhead. Instead, if we shard the weights
during the model initialization, every layer will only create a shard of the
weights it needs, leading to a much smaller memory overhead. The same idea
applies to quantization. Note that we also add an argument `prefix`
to the model's constructor so that the model can initialize itself differently
based on the prefix. This is useful for non-uniform quantization, where
different parts of the model are quantized differently. The `prefix` is
usually an empty string for the top-level model and a string like `"vision"`
or `"language"` for the sub-models. In general, it matches the name of the
module's state dict in the checkpoint file.
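A minimal sketch of prefix-based initialization (the quantization table and class names are hypothetical, only to illustrate the idea):

```python
import torch.nn as nn

# Hypothetical per-submodule quantization table, keyed by prefix.
QUANT_SCHEME = {"vision": "fp16", "language": "int8"}

class SubModelSketch(nn.Module):
    def __init__(self, *, prefix: str = ""):
        super().__init__()
        # The prefix tells the sub-model which checkpoint subtree it owns and
        # lets it pick a prefix-specific quantization scheme.
        self.quant = QUANT_SCHEME.get(prefix, "fp16")

class VisionLanguageModelSketch(nn.Module):
    def __init__(self, *, prefix: str = ""):
        super().__init__()
        self.vision = SubModelSketch(prefix="vision")
        self.language = SubModelSketch(prefix="language")
```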
One disadvantage of this design is that it is hard to write unit tests for
individual components in vLLM because every component needs to be initialized by
a complete config object. We solve this problem by providing a default
initialization function that creates a default config object with all fields set
to `None`. If the component we want to test only cares about a few fields in
the config object, we can create a default config object and set the fields we
care about. This way, we can test the component in isolation. Note that many
tests in vLLM are end-to-end tests that test the whole system, so this is not a
big problem.
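A sketch of that testing pattern (all names here are illustrative, not vLLM's actual helper):

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative default-config pattern, not vLLM's actual helper.
@dataclass
class ConfigSketch:
    tokenizer: Optional[str] = None
    block_size: Optional[int] = None
    dtype: Optional[str] = None

def make_default_config(**overrides) -> ConfigSketch:
    """Create a config with every field None, then set only what the test needs."""
    cfg = ConfigSketch()
    for name, value in overrides.items():
        setattr(cfg, name, value)
    return cfg

# A test of a block-size-dependent component only sets block_size:
cfg = make_default_config(block_size=16)
assert cfg.block_size == 16 and cfg.dtype is None
```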
In summary, the complete config object `VllmConfig` can be treated as an
engine-level global state that is shared among all vLLM classes.

[vllm/engine/async_llm_engine.py]: https://github.com/vllm-project/vllm/tree/main/vllm/engine/async_llm_engine.py
[vllm/engine/llm_engine.py]: https://github.com/vllm-project/vllm/tree/main/vllm/engine/llm_engine.py
[vllm/entrypoints/api_server.py]: https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/api_server.py

docs/source/design/huggingface_integration.md

(huggingface-integration)=

# Integration with HuggingFace

This document describes how vLLM integrates with HuggingFace libraries. We will explain step by step what happens under the hood when we run `vllm serve`.

Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qwen2-7B`.

1. The `model` argument is `Qwen/Qwen2-7B`. vLLM determines whether this model exists by checking for the corresponding config file `config.json`. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L162-L182) for the implementation. Within this process (a sketch of the resolution order is shown below):

   - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path.
   - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works.
   - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file.
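   A hedged sketch of that resolution order using `huggingface_hub` (illustrative only; vLLM's actual implementation lives in `vllm/transformers_utils/config.py`):

   ```python
   import os
   from huggingface_hub import hf_hub_download

   def resolve_config_path(model: str, revision: str = "main") -> str:
       """Illustrative only: find config.json locally, in the HF cache, or on the hub."""
       # 1. Local path: read config.json directly from the directory.
       local_candidate = os.path.join(model, "config.json")
       if os.path.isfile(local_candidate):
           return local_candidate
       # 2./3. HF model ID: hf_hub_download returns the cached file if present,
       # otherwise it downloads the file (using HF_TOKEN from the environment if set).
       return hf_hub_download(repo_id=model, filename="config.json", revision=revision)
   ```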
2. After confirming the existence of the model, vLLM loads its config file and converts it into a dictionary. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L185-L186) for the implementation.

3. Next, vLLM [inspects](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L189) the `model_type` field in the config dictionary to [generate](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#190-L216) the config object to use. There are some `model_type` values that vLLM directly supports; see [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L48) for the list. If the `model_type` is not in the list, vLLM will use [AutoConfig.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained) to load the config class, with `model`, `--revision`, and `--trust_remote_code` as the arguments (see the example below). Please note that:

   - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example.
   - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled.
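   For example, that fallback boils down to a call like this:

   ```python
   from transformers import AutoConfig

   # Equivalent of the fallback path for a model_type vLLM does not know directly.
   # trust_remote_code=True would allow running config code from the model repo.
   config = AutoConfig.from_pretrained(
       "Qwen/Qwen2-7B",
       revision="main",
       trust_remote_code=False,
   )
   print(config.model_type)  # "qwen2"
   ```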
4. Subsequently, vLLM applies some historical patches to the config object. These are mostly related to RoPE configuration; see [here](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/config.py#L244) for the implementation.

5. Finally, vLLM can reach the model class we want to initialize. vLLM uses the `architectures` field in the config object to determine the model class to initialize, as it maintains the mapping from architecture name to model class in [its registry](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/registry.py#L80). If the architecture name is not found in the registry, it means this model architecture is not supported by vLLM. For `Qwen/Qwen2-7B`, the `architectures` field is `["Qwen2ForCausalLM"]`, which corresponds to the `Qwen2ForCausalLM` class in [vLLM's code](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/qwen2.py#L364). This class will initialize itself depending on various configs.
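Conceptually, the registry lookup is a mapping from architecture name to model class. A simplified sketch (the real registry is more elaborate):

```python
# Simplified sketch of the architecture-to-class lookup; not vLLM's actual registry.
MODEL_REGISTRY = {
    "Qwen2ForCausalLM": "vllm.model_executor.models.qwen2:Qwen2ForCausalLM",
    # ... one entry per supported architecture ...
}

def pick_model_class(architectures: list[str]) -> str:
    for arch in architectures:
        if arch in MODEL_REGISTRY:
            return MODEL_REGISTRY[arch]
    raise ValueError(f"Model architectures {architectures} are not supported by vLLM.")

print(pick_model_class(["Qwen2ForCausalLM"]))
```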
Beyond that, there are two more things vLLM depends on HuggingFace for.

1. **Tokenizer**: vLLM uses the tokenizer from HuggingFace to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check HuggingFace's documentation for the meaning of these arguments. Notably, after obtaining the tokenizer, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function.
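   The loading step corresponds to a call like:

   ```python
   from transformers import AutoTokenizer

   # Roughly what vLLM's get_tokenizer does before caching expensive attributes.
   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", revision="main")
   print(tokenizer("Hello, my name is")["input_ids"])
   ```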
2. **Model weight**: vLLM downloads the model weight from the HuggingFace model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights.

   - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that:

This completes the integration between vLLM and HuggingFace.

In summary, vLLM reads the config file `config.json`, the tokenizer, and the model weights from the HuggingFace model hub or a local directory. It uses the config class from either vLLM or HuggingFace transformers, or loads the config class from the model's repository.
docs/source/design/input_processing/input_processing_pipeline.md

(input-processing-pipeline)=

# Input Processing Pipeline

1. Input data is passed to {class}`~vllm.LLMEngine` (or {class}`~vllm.AsyncLLMEngine`).

2. Tokenize the data if necessary.

3. Process the inputs using {meth}`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.

   - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.

4. Send the processed inputs to {class}`~vllm.executor.executor_base.ExecutorBase`.

5. Distribute the inputs via {class}`~vllm.worker.worker_base.WorkerBase` to {class}`~vllm.worker.model_runner_base.ModelRunnerBase`.

6. If the data contains multi-modal data, convert it into keyword arguments using {meth}`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

   - For example, convert a {class}`PIL.Image.Image` input to its pixel values for a vision model.
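A conceptual, runnable sketch of these steps in plain Python (every helper below is an illustrative stand-in for the vLLM components referenced above, not their real interfaces):

```python
def tokenize(prompt: str) -> list[int]:
    return [ord(c) for c in prompt]  # toy stand-in for the real tokenizer (step 2)

def process_input(token_ids: list[int], mm_data) -> dict:
    # Step 3: e.g. prepend placeholder tokens that reserve KV cache slots
    # for multi-modal embeddings.
    placeholders = [0] * 4 if mm_data is not None else []
    return {"prompt_token_ids": placeholders + token_ids, "mm_data": mm_data}

def map_input(mm_data) -> dict:
    # Step 6: e.g. convert an image into pixel-value keyword arguments.
    return {"pixel_values": mm_data}

def run_pipeline(prompt: str, mm_data=None) -> dict:
    inputs = process_input(tokenize(prompt), mm_data)           # steps 1-3
    kwargs = map_input(mm_data) if mm_data is not None else {}  # step 6
    # Steps 4-5 (executor -> worker -> model runner) are process boundaries
    # in vLLM; here everything is handed to a single function.
    return {**inputs, **kwargs}

print(run_pipeline("a photo of", mm_data=[[0.1, 0.2]]))
```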
docs/source/design/input_processing/model_inputs_index.md
(input-processing)=

# Input Processing

```{eval-rst}
.. currentmodule:: vllm.inputs
```

Each model can override parts of vLLM's [input processing pipeline](#input-processing-pipeline) via
{data}`~vllm.inputs.INPUT_REGISTRY` and {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.

Currently, this mechanism is only utilized in [multi-modal](#multi-modality) models for preprocessing multi-modal input
data in addition to the input prompt, but it can be extended to text-only language models when needed.

## Guides

```{toctree}
:maxdepth: 1

input_processing_pipeline
```

## Module Contents

### LLM Engine Inputs

```{eval-rst}
.. autoclass:: vllm.inputs.DecoderOnlyInputs
    :members:
    :show-inheritance:
```

### Registry

```{eval-rst}
.. autodata:: vllm.inputs.INPUT_REGISTRY
```

```{eval-rst}
.. automodule:: vllm.inputs.registry
    :members:
    :show-inheritance:
```
docs/source/design/kernel/paged_attention.md
# vLLM Paged Attention

- Currently, vLLM utilizes its own implementation of a multi-head query
  attention kernel (`csrc/attention/attention_kernels.cu`).
  This kernel is designed to be compatible with
  vLLM's paged KV caches, where the key and value cache are stored in
  separate blocks (note that this block concept differs from the GPU
  thread block, so in this document I will refer to a vLLM paged
  attention block as a "block" and to a GPU thread block as a
  "thread block").
- To achieve high performance, this kernel relies on a specially
  designed memory layout and access method, specifically when threads
  read data from global memory to shared memory. The purpose of this
  document is to provide a high-level explanation of the kernel
  implementation step by step, aiding those who wish to learn about the
  vLLM multi-head query attention kernel. After going through this
  document, users will likely have a better understanding and find it easier
  to follow the actual implementation.
- Please note that this document may not cover all details, such as how
  to calculate the correct index for the corresponding data or the dot
  multiplication implementation. However, after reading this document
  and becoming familiar with the high-level logic flow, it should be
  easier for you to read the actual code and understand the details.

## Inputs

- The kernel function takes a list of arguments for the current thread
  to perform its assigned work. The three most important arguments are
  the input pointers `q`, `k_cache`, and `v_cache`, which point
  to query, key, and value data in global memory that need to be read
  and processed. The output pointer `out` points to global memory
  where the result should be written. These four pointers actually
  refer to multi-dimensional arrays, but each thread only accesses the
  portion of data assigned to it. I have omitted all other runtime
  parameters here for simplicity.

```cpp
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
  ... // Other side args.
  const scalar_t* __restrict__ out,     // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
  ... // Other side args.
)
```

- There is also a list of template arguments above the function
  signature that are determined at compilation time. `scalar_t`
  represents the data type of the query, key, and value data elements,
  such as FP16. `HEAD_SIZE` indicates the number of elements in each
  head. `BLOCK_SIZE` refers to the number of tokens in each block.
  `NUM_THREADS` denotes the number of threads in each thread block.
  `PARTITION_SIZE` represents the number of tensor parallel GPUs (for
  simplicity, we assume this is 0 and tensor parallelism is disabled).
- With these arguments, we need to perform a sequence of preparations.
  This includes calculating the current head index, block index, and
  other necessary variables. However, for now, we can ignore these
  preparations and proceed directly to the actual calculations. It will
  be easier to understand them once we grasp the entire flow.

## Concepts

- Just before we dive into the calculation flow, I want to describe a
  few concepts that are needed for later sections. However, you may
  skip this section and return later if you encounter any confusing
  terminology. A short numeric sketch of these quantities follows the list.
- **Sequence**: A sequence represents a client request. For example,
  the data pointed to by `q` has a shape of
  `[num_seqs, num_heads, head_size]`, which means `q` points to a total of
  `num_seqs` query sequences. Since this
  kernel is a single-query attention kernel, each sequence only has one
  query token. Hence, `num_seqs` equals the total number of tokens
  that are processed in the batch.
- **Context**: The context consists of the generated tokens from the
  sequence. For instance, `["What", "is", "your"]` are the context
  tokens, and the input query token is `"name"`. The model might
  generate the token `"?"`.
- **Vec**: A vec is a list of elements that are fetched and
  calculated together. For query and key data, the vec size
  (`VEC_SIZE`) is determined so that each thread group can fetch and
  calculate 16 bytes of data at a time. For value data, the vec size
  (`V_VEC_SIZE`) is determined so that each thread can fetch and
  calculate 16 bytes of data at a time. For example, if
  `scalar_t` is FP16 (2 bytes) and `THREAD_GROUP_SIZE` is 2, then
  `VEC_SIZE` will be 4, while `V_VEC_SIZE` will be 8.
- **Thread group**: A thread group is a small group of
  threads (`THREAD_GROUP_SIZE`) that fetches and calculates one
  query token and one key token at a time. Each thread handles only a
  portion of the token data. The total number of elements processed by
  one thread group is referred to as `x`. For example, if the thread
  group contains 2 threads and the head size is 8, then thread 0
  handles the query and key elements at indices 0, 2, 4, 6, while thread
  1 handles the elements at indices 1, 3, 5, 7.
- **Block**: The key and value cache data in vLLM are split into
  blocks. Each block stores data for a fixed number (`BLOCK_SIZE`)
  of tokens at one head. Each block may contain only a portion of the
  whole context's tokens. For example, if the block size is 16 and the
  head size is 128, then for one head, one block can store 16 * 128 =
  2048 elements.
- **Warp**: A warp is a group of 32 threads (`WARP_SIZE`) that
  execute simultaneously on a streaming multiprocessor (SM). In this
  kernel, each warp processes the calculation between one query token
  and the key tokens of one entire block at a time (it may process multiple
  blocks in multiple iterations). For example, if there are 4 warps and
  6 blocks for one context, warp 0 handles
  the 0th and 4th blocks, warp 1 handles the 1st and 5th blocks, warp 2
  handles the 2nd block, and warp 3 handles the 3rd block.
- **Thread block**: A thread block is a group of
  threads (`NUM_THREADS`) that can access the same shared memory.
  Each thread block contains multiple warps (`NUM_WARPS`), and in
  this kernel, each thread block processes the calculation between one
  query token and the key tokens of a whole context.
- **Grid**: A grid is a collection of thread blocks and defines the
  shape of the collection. In this kernel, the shape is
  `(num_heads, num_seqs, max_num_partitions)`. Therefore, each thread
  block only handles the calculation for one head, one sequence, and
  one partition.
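A quick numeric sketch of these quantities for an FP16 kernel, using the example constants from the list above (illustrative arithmetic only):

```python
# Illustrative arithmetic for the concepts above (FP16 example).
DTYPE_BYTES = 2          # scalar_t = FP16
THREAD_GROUP_SIZE = 2
HEAD_SIZE = 128
BLOCK_SIZE = 16

# Each thread group fetches 16 bytes of query/key data per access;
# each single thread fetches 16 bytes of value data per access.
VEC_SIZE = 16 // (THREAD_GROUP_SIZE * DTYPE_BYTES)  # -> 4
V_VEC_SIZE = 16 // DTYPE_BYTES                      # -> 8

# One block holds BLOCK_SIZE tokens of one head.
elements_per_block = BLOCK_SIZE * HEAD_SIZE         # -> 2048

print(VEC_SIZE, V_VEC_SIZE, elements_per_block)
```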
## Query

- This section will introduce how query data is stored in memory and
  fetched by each thread. As mentioned above, each thread group fetches
  one query token's data, while each thread itself only handles a part of
  one query token's data. Within each warp, every thread group will fetch
  the same query token data, but will multiply it with different key
  token data.

```cpp
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```

```{figure} ../../assets/kernel/query.png
:align: center
:alt: query
:width: 70%

Query data of one token at one head
```

- Each thread defines its own `q_ptr`, which points to the assigned
  query token data in global memory. For example, if `VEC_SIZE` is 4
  and `HEAD_SIZE` is 128, `q_ptr` points to data that contains a
  total of 128 elements divided into 128 / 4 = 32 vecs.
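The vec bookkeeping for this example can be sanity-checked with a little arithmetic (an illustrative sketch; the exact assignment loop lives in the kernel source):

```python
# Illustrative check of the query vec layout for one thread group.
HEAD_SIZE, VEC_SIZE, THREAD_GROUP_SIZE = 128, 4, 2

num_vecs = HEAD_SIZE // VEC_SIZE                     # 32 vecs per token
NUM_VECS_PER_THREAD = num_vecs // THREAD_GROUP_SIZE  # 16 vecs per thread

# Vecs interleave across the thread group (thread 0: vecs 0, 2, 4, ...;
# thread 1: vecs 1, 3, 5, ...), so neighboring threads touch neighboring
# memory and the accesses coalesce.
for tid in range(THREAD_GROUP_SIZE):
    vec_ids = list(range(tid, num_vecs, THREAD_GROUP_SIZE))
    print(f"thread {tid} loads vecs {vec_ids[:4]} ... ({len(vec_ids)} total)")
```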
```{figure} ../../assets/kernel/q_vecs.png
:align: center
:alt: q_vecs
:width: 70%

`q_vecs` for one thread group
```

```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
```

- Next, we need to read the global memory data pointed to by `q_ptr`
  into shared memory as `q_vecs`. It is important to note that each
  vec is assigned to a different row. For example, if
  `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row of vecs,
  while thread 1 handles the 1st row. By reading the query data in
  this way, neighboring threads like thread 0 and thread 1 read
  neighboring memory, achieving the memory coalescing that improves
  performance.

## Key

- Similar to the "Query" section, this section introduces the memory layout
  and assignment for keys. While each thread group only handles one
  query token per kernel run, it may handle multiple key tokens across
  multiple iterations. Meanwhile, each warp will process multiple blocks
  of key tokens in multiple iterations, ensuring that all context
  tokens are processed by the entire thread group after the kernel run.
  In this context, "handle" refers to performing the dot multiplication
  between query data and key data.

```cpp
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                      + kv_head_idx * kv_head_stride
                      + physical_block_offset * x;
```

- Unlike `q_ptr`, `k_ptr` in each thread will point to a different
  key token at different iterations. As shown above, `k_ptr`
  points to key token data based on `k_cache` at the assigned block,
  assigned head, and assigned token.

```{figure} ../../assets/kernel/key.png
:align: center
:alt: key
:width: 70%

Key data of all context tokens at one head
```

- The diagram above illustrates the memory layout for key data. It
  assumes that `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
  8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each
  rectangle represents all the elements for one key token at one head,
  which will be processed by one thread group. The left half shows the
  total 16 blocks of key token data for warp 0, while the right half
  represents the remaining key token data for other warps or
  iterations. Inside each rectangle, there are a total of 32 vecs (128
  elements for one token) that will be processed by 2 threads (one
  thread group) separately.

```{figure} ../../assets/kernel/k_vecs.png
:align: center
:alt: k_vecs
:width: 70%

`k_vecs` for one thread
```

```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD]
```

- Next, we need to read the key token data from `k_ptr` and store
  it in register memory as `k_vecs`. We use register memory for
  `k_vecs` because it will only be accessed by one thread once,
  whereas `q_vecs` will be accessed by multiple threads multiple
  times. Each `k_vecs` will contain multiple vectors for later
  calculation. Each vec will be set at each inner iteration. The
  assignment of vecs allows neighboring threads in a warp to read
  neighboring memory together, which again promotes memory
  coalescing. For instance, thread 0 will read vec 0, while thread 1
  will read vec 1. In the next inner loop, thread 0 will read vec 2,
  while thread 1 will read vec 3, and so on.

- You may still be a little confused about the overall flow. Don't
  worry, please keep reading the next "QK" section. It will illustrate
  the query and key calculation flow in a clearer and higher-level
  manner.

## QK

- As shown in the pseudocode below, before the entire for-loop block, we
  fetch the query data for one token and store it in `q_vecs`. Then,
  in the outer for loop, we iterate through different `k_ptr`s that
  point to different tokens and prepare the `k_vecs` in the inner for
  loop. Finally, we perform the dot multiplication between
  `q_vecs` and each `k_vecs`.

```cpp
q_vecs = ...
for ... {
    k_ptr = ...
    for ... {
        k_vecs[i] = ...
    }
    ...
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
```

- As mentioned before, each thread only fetches part of the
  query and key token data at a time. However, a cross-thread-group
  reduction happens inside `Qk_dot<>::dot`, so the `qk`
  returned here is not just the dot product of the partial query and key
  token data, but actually the full result between the entire query and
  key token data.

- For example, if the value of `HEAD_SIZE` is 128 and
  `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain
  a total of 64 elements. However, the returned `qk` is actually the
  result of the dot multiplication between 128 query elements and 128 key
  elements. If you want to learn more about the details of the dot
  multiplication and reduction, you may refer to the implementation of
  `Qk_dot<>::dot`. However, for the sake of simplicity, I will not
  cover it in this document.
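A small numeric illustration of why the cross-group reduction yields the full dot product (a pure-Python sketch, not the CUDA reduction itself):

```python
import random

# Two threads each hold an interleaved half of a 128-element head.
HEAD_SIZE, THREAD_GROUP_SIZE = 128, 2
q = [random.random() for _ in range(HEAD_SIZE)]
k = [random.random() for _ in range(HEAD_SIZE)]

# Thread t owns indices t, t + 2, t + 4, ... (64 elements each).
partials = [
    sum(q[i] * k[i] for i in range(t, HEAD_SIZE, THREAD_GROUP_SIZE))
    for t in range(THREAD_GROUP_SIZE)
]

# The reduction inside Qk_dot<>::dot sums the partials, so the returned qk
# is the full 128-element dot product, not a per-thread fragment.
qk = sum(partials)
assert abs(qk - sum(qi * ki for qi, ki in zip(q, k))) < 1e-9
```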
## Softmax

- Next, we need to calculate the normalized softmax for all `qk`s,
  as shown below, where each $x$ represents a `qk`. To do this,
  we must obtain the reduced value of `qk_max` ($m(x)$) and
  the `exp_sum` ($\ell(x)$) of all `qk`s. The reduction
  should be performed across the entire thread block, encompassing
  results between the query token and all context key tokens.

```{math}
:nowrap: true

\begin{gather*}
m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\
\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
\end{gather*}
```
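In plain Python, the same numerically stable softmax reads:

```python
import math

def stable_softmax(qks: list[float]) -> list[float]:
    qk_max = max(qks)                        # m(x), reduced across the thread block
    f = [math.exp(x - qk_max) for x in qks]  # f(x): exponentials with the max subtracted
    exp_sum = sum(f)                         # l(x)
    return [v / exp_sum for v in f]          # softmax(x) = f(x) / l(x)

print(stable_softmax([1.0, 2.0, 3.0]))
```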
|
||||
|
||||
### `qk_max` and `logits`
|
||||
|
||||
- Just right after we get the `qk` result, we can set the temporary
|
||||
`logits` result with `qk` (In the end, the `logits` should
|
||||
store the normalized softmax result). Also we can compare and collect
|
||||
the `qk_max` for all `qk`s that are calculated by current
|
||||
thread group.
|
||||
|
||||
```cpp
|
||||
if (thread_group_offset == 0) {
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
}
|
||||
```
|
||||
|
||||
- Please note that the `logits` here is on shared memory, so each
|
||||
thread group will set the fields for its own assigned context tokens.
|
||||
Overall, the size of logits should be number of context tokens.
|
||||
|
||||
```cpp
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
}
|
||||
```
|
||||
|
||||
- Then we need to get the reduced `qk_max` across each warp. The main
|
||||
idea is to make threads in warp to communicate with each other and
|
||||
get the final max `qk` .
|
||||
|
||||
```cpp
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
```
|
||||
|
||||
- Finally, we can get the reduced `qk_max` from whole thread block by
|
||||
compare the `qk_max` from all warps in this thread block. Then we
|
||||
need to broadcast the final result to each thread.
|
||||
|
||||
### `exp_sum`
|
||||
|
||||
- Similar to `qk_max`, we need to get the reduced sum value from the
|
||||
entire thread block too.
|
||||
|
||||
```cpp
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
float val = __expf(logits[i] - qk_max);
|
||||
logits[i] = val;
|
||||
exp_sum += val;
|
||||
}
|
||||
...
|
||||
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
||||
```
|
||||
|
||||
- Firstly, sum all exp values from each thread group, and meanwhile,
|
||||
convert each entry of `logits` from `qk` to `exp(qk - qk_max)`.
|
||||
Please note, the `qk_max` here is already the max `qk` across the
|
||||
whole thread block. And then we can do reduction for `exp_sum`
|
||||
across whole thread block just like the `qk_max`.
|
||||
|
||||
```cpp
|
||||
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
logits[i] *= inv_sum;
|
||||
}
|
||||
```
|
||||
|
||||
- Finally, with the reduced `qk_max` and `exp_sum`, we can obtain
|
||||
the final normalized softmax result as `logits`. This `logits`
|
||||
variable will be used for dot multiplication with the value data in
|
||||
later steps. Now, it should store the normalized softmax result of
|
||||
`qk` for all assigned context tokens.
|
||||
|
||||
## Value
|
||||
|
||||
```{figure} ../../assets/kernel/value.png
|
||||
:align: center
|
||||
:alt: value
|
||||
:width: 70%
|
||||
|
||||
Value data of all context tokens at one head
|
||||
```
|
||||
|
||||
```{figure} ../../assets/kernel/logits_vec.png
|
||||
:align: center
|
||||
:alt: logits_vec
|
||||
:width: 50%
|
||||
|
||||
`logits_vec` for one thread
|
||||
```
|
||||
|
||||
```{figure} ../../assets/kernel/v_vec.png
|
||||
:align: center
|
||||
:alt: v_vec
|
||||
:width: 70%
|
||||
|
||||
List of `v_vec` for one thread
|
||||
```
|
||||
|
||||
- Now we need to retrieve the value data and perform dot multiplication
|
||||
with `logits`. Unlike query and key, there is no thread group
|
||||
concept for value data. As shown in diagram, different from key token
|
||||
memory layout, elements from the same column correspond to the same
|
||||
value token. For one block of value data, there are `HEAD_SIZE` of
|
||||
rows and `BLOCK_SIZE` of columns that are split into multiple
|
||||
`v_vecs`.
|
||||
|
||||
- Each thread always fetches `V_VEC_SIZE` elements from the same
  `V_VEC_SIZE` tokens at a time. As a result, a single thread retrieves
  multiple `v_vec`s from different rows and the same columns through
  multiple inner iterations. For each `v_vec`, it needs to be dot
  multiplied with the corresponding `logits_vec`, which is also
  `V_VEC_SIZE` elements from `logits`. Overall, with multiple inner
  iterations, each warp will process one block of value tokens, and with
  multiple outer iterations, the whole context's value tokens are
  processed.

```cpp
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
    logits_vec = ...
    for ... { // Iteration over different rows.
        v_vec = ...
        ...
        accs[i] += dot(logits_vec, v_vec);
    }
}
```

- As shown in the above pseudo code, in the outer loop, similar to
  `k_ptr`, `logits_vec` iterates over different blocks and reads
  `V_VEC_SIZE` elements from `logits`. In the inner loop, each thread
  reads `V_VEC_SIZE` elements from the same tokens as a `v_vec` and
  performs dot multiplication. It is important to note that in each
  inner iteration, the thread fetches elements at different head
  positions for the same tokens. The dot result is then accumulated in
  `accs`. Therefore, each entry of `accs` is mapped to a head position
  assigned to the current thread.

- For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each thread
  fetches 8 value elements for 8 tokens at a time. Each element is from
  a different token at the same head position. If `HEAD_SIZE` is 128 and
  `WARP_SIZE` is 32, for each inner loop, a warp needs to fetch
  `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are a total
  of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole
  block of value tokens. And the `accs` in each thread contains 8
  elements that are accumulated at 8 different head positions. For
  thread 0, the `accs` variable will have 8 elements, which are the 0th,
  32nd, …, 224th elements of a value head, accumulated from all 8
  assigned tokens.

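The iteration count above can be checked with a few lines of standalone C++
(a sketch of the arithmetic only; the constant names mirror the kernel's but
are re-declared here):

```cpp
#include <cstdio>

int main() {
  const int BLOCK_SIZE = 16;   // tokens per block
  const int HEAD_SIZE  = 128;  // elements per head
  const int WARP_SIZE  = 32;   // threads per warp
  const int V_VEC_SIZE = 8;    // elements fetched per thread per access

  const int elems_per_block = HEAD_SIZE * BLOCK_SIZE;  // 2048
  const int elems_per_iter  = WARP_SIZE * V_VEC_SIZE;  // 256
  printf("inner iterations per block: %d\n",
         elems_per_block / elems_per_iter);            // 8
  return 0;
}
```
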
## LV

- Now, we need to perform reduction for `accs` within each warp. This
  process allows each thread to accumulate the `accs` for the assigned
  head positions of all tokens in one block.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  float acc = accs[i];
  for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
    acc += VLLM_SHFL_XOR_SYNC(acc, mask);
  }
  accs[i] = acc;
}
```

- Next, we perform reduction for `accs` across all warps, allowing each
  thread to have the accumulation of `accs` for the assigned head
  positions of all context tokens. Please note that each `accs` in every
  thread only stores the accumulation for a portion of the elements of
  the entire head for all context tokens. However, overall, all results
  for the output have been calculated; they are just stored in different
  threads' register memory (a simplified host-side sketch follows the
  code below).

```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
  // Upper warps write to shared memory.
  ...
  float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    ...
    dst[row_idx] = accs[i];
  }

  // Lower warps update the output.
  const float* src = &out_smem[warp_idx * HEAD_SIZE];
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    ...
    accs[i] += src[row_idx];
  }

  // Write out the accs.
}
```

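The halving loop above is a classic tree reduction. A minimal host-side C++
sketch (illustrative only, with a single value per "warp" instead of full
rows of `accs`) shows how the upper half repeatedly folds into the lower
half:

```cpp
#include <cstdio>

int main() {
  const int NUM_WARPS = 4;
  float acc[NUM_WARPS] = {1.f, 2.f, 3.f, 4.f};  // per-warp partial results

  for (int n = NUM_WARPS; n > 1; n /= 2) {
    int mid = n / 2;
    // "Warps" [mid, n) write their partials; [0, mid) accumulate them.
    for (int w = 0; w < mid; ++w)
      acc[w] += acc[w + mid];
  }
  printf("block result: %f\n", acc[0]);  // 10.0
  return 0;
}
```
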
## Output

- Now we can write all of the calculated result from local register
  memory to the final output global memory.

```cpp
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                        + head_idx * max_num_partitions * HEAD_SIZE
                        + partition_idx * HEAD_SIZE;
```

- First, we need to define the `out_ptr` variable, which points to the
  start address of the assigned sequence and the assigned head.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
  const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
  if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
    from_float(*(out_ptr + row_idx), accs[i]);
  }
}
```

- Finally, we need to iterate over different assigned head positions and
  write out the corresponding accumulated result based on `out_ptr`.

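To make the `out_ptr` indexing concrete, here is a small standalone C++
sketch (illustrative index values only) that computes the flat offset of one
partition's output slice inside the
`[num_seqs, num_heads, max_num_partitions, head_size]` tensor:

```cpp
#include <cstdio>

int main() {
  // Shape of `out`: [num_seqs, num_heads, max_num_partitions, head_size].
  const int num_heads = 8, max_num_partitions = 4, HEAD_SIZE = 128;
  const int seq_idx = 2, head_idx = 3, partition_idx = 1;  // example indices

  // Same arithmetic as the kernel's out_ptr computation.
  const long offset = (long)seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                    + (long)head_idx * max_num_partitions * HEAD_SIZE
                    + (long)partition_idx * HEAD_SIZE;
  printf("out_ptr = out + %ld\n", offset);  // out + 9856
  return 0;
}
```
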
@ -1,525 +0,0 @@
vLLM Paged Attention
====================

- Currently, vLLM utilizes its own implementation of a multi-head query
  attention kernel (``csrc/attention/attention_kernels.cu``).
  This kernel is designed to be compatible with
  vLLM's paged KV caches, where the key and value cache are stored in
  separate blocks (note that this block concept differs from the GPU
  thread block; in this document, "block" refers to the vLLM paged
  attention block, while "thread block" refers to the GPU thread block).
- To achieve high performance, this kernel relies on a specially
  designed memory layout and access method, specifically when threads
  read data from global memory to shared memory. The purpose of this
  document is to provide a high-level explanation of the kernel
  implementation step by step, aiding those who wish to learn about the
  vLLM multi-head query attention kernel. After going through this
  document, users will likely have a better understanding and find it
  easier to follow the actual implementation.
- Please note that this document may not cover all details, such as how
  to calculate the correct index for the corresponding data or the dot
  multiplication implementation. However, after reading this document
  and becoming familiar with the high-level logic flow, it should be
  easier for you to read the actual code and understand the details.

Inputs
------

- The kernel function takes a list of arguments for the current thread
  to perform its assigned work. The three most important arguments are
  the input pointers ``q``, ``k_cache``, and ``v_cache``, which point
  to query, key, and value data on global memory that need to be read
  and processed. The output pointer ``out`` points to global memory
  where the result should be written. These four pointers actually
  refer to multi-dimensional arrays, but each thread only accesses the
  portion of data assigned to it. I have omitted all other runtime
  parameters here for simplicity.

.. code:: cpp

   template<
     typename scalar_t,
     int HEAD_SIZE,
     int BLOCK_SIZE,
     int NUM_THREADS,
     int PARTITION_SIZE = 0>
   __device__ void paged_attention_kernel(
     ... // Other side args.
     const scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
     const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
     const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
     ... // Other side args.
   )

- There is also a list of template arguments above the function
  signature that are determined at compilation time. ``scalar_t``
  represents the data type of the query, key, and value data elements,
  such as FP16. ``HEAD_SIZE`` indicates the number of elements in each
  head. ``BLOCK_SIZE`` refers to the number of tokens in each block.
  ``NUM_THREADS`` denotes the number of threads in each thread block.
  ``PARTITION_SIZE`` represents the number of tensor parallel GPUs (for
  simplicity, we assume this is 0 and tensor parallelism is disabled).
- With these arguments, we need to perform a sequence of preparations.
  This includes calculating the current head index, block index, and
  other necessary variables. However, for now, we can ignore these
  preparations and proceed directly to the actual calculations. It will
  be easier to understand them once we grasp the entire flow.

Concepts
--------

- Just before we dive into the calculation flow, I want to describe a
  few concepts that are needed for later sections. However, you may
  skip this section and return later if you encounter any confusing
  terminology.
- **Sequence**: A sequence represents a client request. For example,
  the data pointed to by ``q`` has a shape of
  ``[num_seqs, num_heads, head_size]``, which means that ``q`` points to
  a total of ``num_seqs`` query sequences. Since this
  kernel is a single query attention kernel, each sequence only has one
  query token. Hence, ``num_seqs`` equals the total number of tokens
  that are processed in the batch.
- **Context**: The context consists of the generated tokens from the
  sequence. For instance, ``["What", "is", "your"]`` are the context
  tokens, and the input query token is ``"name"``. The model might
  generate the token ``"?"``.
- **Vec**: The vec is a list of elements that are fetched and
  calculated together. For query and key data, the vec size
  (``VEC_SIZE``) is determined so that each thread group can fetch and
  calculate 16 bytes of data at a time. For value data, the vec size
  (``V_VEC_SIZE``) is determined so that each thread can fetch and
  calculate 16 bytes of data at a time. For example, if
  ``scalar_t`` is FP16 (2 bytes) and ``THREAD_GROUP_SIZE`` is 2, the
  ``VEC_SIZE`` will be 4, while the ``V_VEC_SIZE`` will be 8 (a small
  sketch after this list reproduces this arithmetic).
- **Thread group**: The thread group is a small group of threads
  (``THREAD_GROUP_SIZE``) that fetches and calculates one
  query token and one key token at a time. Each thread handles only a
  portion of the token data. The total number of elements processed by
  one thread group is referred to as ``x``. For example, if the thread
  group contains 2 threads and the head size is 8, then thread 0
  handles the query and key elements at indices 0, 2, 4, 6, while thread
  1 handles the elements at indices 1, 3, 5, 7.
- **Block**: The key and value cache data in vLLM are split into
  blocks. Each block stores data for a fixed number (``BLOCK_SIZE``)
  of tokens at one head. Each block may contain only a portion of the
  whole context's tokens. For example, if the block size is 16 and the
  head size is 128, then for one head, one block can store 16 \* 128 =
  2048 elements.
- **Warp**: A warp is a group of 32 threads (``WARP_SIZE``) that
  execute simultaneously on a streaming multiprocessor (SM). In this
  kernel, each warp processes the calculation between one query token
  and the key tokens of one entire block at a time (it may process
  multiple blocks in multiple iterations). For example, if there are 4
  warps and 6 blocks for one context, warp 0 handles the 0th and 4th
  blocks, warp 1 handles the 1st and 5th blocks, warp 2 handles the 2nd
  block, and warp 3 handles the 3rd block.
- **Thread block**: A thread block is a group of threads
  (``NUM_THREADS``) that can access the same shared memory.
  Each thread block contains multiple warps (``NUM_WARPS``), and in
  this kernel, each thread block processes the calculation between one
  query token and the key tokens of a whole context.
- **Grid**: A grid is a collection of thread blocks and defines the
  shape of the collection. In this kernel, the shape is
  ``(num_heads, num_seqs, max_num_partitions)``. Therefore, each thread
  block only handles the calculation for one head, one sequence, and
  one partition.

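To make the vec sizes concrete, here is a small standalone C++ sketch
(illustrative only; the real kernel derives these constants at compile time)
that reproduces the 16-byte rule from the example above:

.. code:: cpp

   #include <cstdio>

   int main() {
     const int sizeof_scalar_t = 2;     // FP16
     const int THREAD_GROUP_SIZE = 2;

     // Each thread group fetches 16 bytes of query/key data at a time...
     const int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof_scalar_t);  // 4
     // ...while each single thread fetches 16 bytes of value data at a time.
     const int V_VEC_SIZE = 16 / sizeof_scalar_t;                      // 8

     printf("VEC_SIZE = %d, V_VEC_SIZE = %d\n", VEC_SIZE, V_VEC_SIZE);
     return 0;
   }
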
Query
-----

- This section will introduce how query data is stored in memory and
  fetched by each thread. As mentioned above, each thread group fetches
  one query token's data, while each thread itself only handles a part
  of one query token's data. Within each warp, every thread group will
  fetch the same query token data, but will multiply it with different
  key token data.

.. code:: cpp

   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;

.. figure:: ../../assets/kernel/query.png
   :alt: query
   :width: 70%
   :align: center

   Query data of one token at one head

- Each thread defines its own ``q_ptr``, which points to the assigned
  query token's data on global memory. For example, if ``VEC_SIZE`` is 4
  and ``HEAD_SIZE`` is 128, ``q_ptr`` points to data that contains a
  total of 128 elements divided into 128 / 4 = 32 vecs.

.. figure:: ../../assets/kernel/q_vecs.png
   :alt: q_vecs
   :width: 70%
   :align: center

   ``q_vecs`` for one thread group

.. code:: cpp

   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

- Next, we need to read the global memory data pointed to by ``q_ptr``
  into shared memory as ``q_vecs``. It is important to note that each
  vec is assigned to a different row. For example, if
  ``THREAD_GROUP_SIZE`` is 2, thread 0 will handle the 0th row of vecs,
  while thread 1 handles the 1st row. By reading the query data in
  this way, neighboring threads like thread 0 and thread 1 can read
  neighboring memory, achieving memory coalescing to improve
  performance.

Key
---

- Similar to the "Query" section, this section introduces the memory
  layout and assignment for keys. While each thread group only handles
  one query token in one kernel run, it may handle multiple key tokens
  across multiple iterations. Meanwhile, each warp will process multiple
  blocks of key tokens in multiple iterations, ensuring that all context
  tokens are processed by the entire thread group after the kernel run.
  In this context, "handle" refers to performing the dot multiplication
  between query data and key data.

.. code:: cpp

   const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride
                                   + physical_block_offset * x;

- Unlike ``q_ptr``, ``k_ptr`` in each thread will point to a different
  key token at different iterations. As shown above, ``k_ptr``
  points to key token data based on ``k_cache`` at the assigned block,
  assigned head, and assigned token.

.. figure:: ../../assets/kernel/key.png
   :alt: key
   :width: 70%
   :align: center

   Key data of all context tokens at one head

- The diagram above illustrates the memory layout for key data. It
  assumes that ``BLOCK_SIZE`` is 16, ``HEAD_SIZE`` is 128, ``x`` is
  8, ``THREAD_GROUP_SIZE`` is 2, and there are a total of 4 warps. Each
  rectangle represents all the elements for one key token at one head,
  which will be processed by one thread group. The left half shows the
  total 16 blocks of key token data for warp 0, while the right half
  represents the remaining key token data for other warps or
  iterations. Inside each rectangle, there are a total of 32 vecs (128
  elements for one token) that will be processed by 2 threads (one
  thread group) separately.

.. figure:: ../../assets/kernel/k_vecs.png
   :alt: k_vecs
   :width: 70%
   :align: center

   ``k_vecs`` for one thread

.. code:: cpp

   K_vec k_vecs[NUM_VECS_PER_THREAD];

- Next, we need to read the key token data from ``k_ptr`` and store
  it in register memory as ``k_vecs``. We use register memory for
  ``k_vecs`` because it will only be accessed by one thread once,
  whereas ``q_vecs`` will be accessed by multiple threads multiple
  times. Each ``k_vecs`` will contain multiple vectors for later
  calculation. Each vec will be set at each inner iteration. The
  assignment of vecs allows neighboring threads in a warp to read
  neighboring memory together, which again promotes memory
  coalescing. For instance, thread 0 will read vec 0, while thread 1
  will read vec 1. In the next inner loop, thread 0 will read vec 2,
  while thread 1 will read vec 3, and so on.
- You may still be a little confused about the overall flow. Don't
  worry, please keep reading the next "QK" section. It will illustrate
  the query and key calculation flow in a clearer and higher-level
  manner.

QK
---

- As shown in the pseudo code below, before the entire for loop block,
  we fetch the query data for one token and store it in ``q_vecs``.
  Then, in the outer for loop, we iterate through different ``k_ptrs``
  that point to different tokens and prepare the ``k_vecs`` in the inner
  for loop. Finally, we perform the dot multiplication between the
  ``q_vecs`` and each ``k_vecs``.

.. code:: cpp

   q_vecs = ...
   for ... {
      k_ptr = ...
      for ... {
         k_vecs[i] = ...
      }
      ...
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
   }

- As mentioned before, each thread only fetches part of the query and
  key token data at a time. However, there is a cross thread group
  reduction in ``Qk_dot<>::dot``, so the ``qk`` returned here is not
  just the dot product of part of the query and key token data, but
  actually the full result between the entire query and key token data.
- For example, if the value of ``HEAD_SIZE`` is 128 and
  ``THREAD_GROUP_SIZE`` is 2, each thread's ``k_vecs`` will contain a
  total of 64 elements. However, the returned ``qk`` is actually the
  result of the dot multiplication between 128 query elements and 128
  key elements. If you want to learn more about the details of the dot
  multiplication and reduction, you may refer to the implementation of
  ``Qk_dot<>::dot``. However, for the sake of simplicity, I will not
  cover it in this document (a simplified sketch of the idea follows
  this list).

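As an illustration of the idea, rather than the actual ``Qk_dot<>::dot``
implementation, the following standalone C++ sketch shows how two per-thread
partial dot products over 64 elements each combine into the full 128-element
result:

.. code:: cpp

   #include <cstdio>

   int main() {
     const int HEAD_SIZE = 128, THREAD_GROUP_SIZE = 2;
     const int PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;  // 64

     float q[HEAD_SIZE], k[HEAD_SIZE];
     for (int i = 0; i < HEAD_SIZE; ++i) { q[i] = 1.f; k[i] = 2.f; }

     // Thread t handles the strided element subset i % THREAD_GROUP_SIZE == t.
     float partial[THREAD_GROUP_SIZE] = {0.f, 0.f};
     for (int i = 0; i < HEAD_SIZE; ++i)
       partial[i % THREAD_GROUP_SIZE] += q[i] * k[i];

     // The cross-thread-group reduction sums the partials into the full qk.
     float qk = partial[0] + partial[1];
     printf("each thread accumulates %d elements; qk = %f\n", PER_THREAD, qk);
     return 0;  // qk = 256.0, the full 128-element dot product
   }
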
Softmax
-------

- Next, we need to calculate the normalized softmax for all ``qk``\ s,
  as shown below, where each :math:`x` represents a ``qk``. To do this,
  we must obtain the reduced value of ``qk_max`` (:math:`m(x)`) and
  the ``exp_sum`` (:math:`\ell(x)`) of all ``qk``\ s. The reduction
  should be performed across the entire thread block, encompassing
  results between the query token and all context key tokens.

.. math::
   :nowrap:

   \begin{gather*}
   m(x) := \max_i \; x_i \\
   f(x) := \left[\begin{array}{lll} e^{x_1 - m(x)} & \ldots & e^{x_B - m(x)} \end{array}\right] \\
   \ell(x) := \sum_i f(x)_i \\
   \operatorname{softmax}(x) := \frac{f(x)}{\ell(x)}
   \end{gather*}

``qk_max`` and ``logits``
~~~~~~~~~~~~~~~~~~~~~~~~~

- Right after we get the ``qk`` result, we can set the temporary
  ``logits`` result with ``qk`` (in the end, ``logits`` should
  store the normalized softmax result). We can also compare and collect
  the ``qk_max`` for all ``qk``\ s that are calculated by the current
  thread group.

.. code:: cpp

   if (thread_group_offset == 0) {
      const bool mask = token_idx >= context_len;
      logits[token_idx - start_token_idx] = mask ? 0.f : qk;
      qk_max = mask ? qk_max : fmaxf(qk_max, qk);
   }

- Please note that the ``logits`` here is on shared memory, so each
  thread group will set the fields for its own assigned context tokens.
  Overall, the size of ``logits`` should be the number of context
  tokens.

.. code:: cpp

   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
   }

   if (lane == 0) {
      red_smem[warp_idx] = qk_max;
   }

- Then we need to get the reduced ``qk_max`` in each warp. The main
  idea is to let the threads in a warp communicate with each other and
  get the final max ``qk``.

.. code:: cpp

   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
   }
   qk_max = VLLM_SHFL_SYNC(qk_max, 0);

- Finally, we can get the reduced ``qk_max`` for the whole thread block
  by comparing the ``qk_max`` values from all warps in this thread
  block. Then we need to broadcast the final result to each thread.

``exp_sum``
~~~~~~~~~~~

- Similar to ``qk_max``, we need to get the reduced sum value from the
  entire thread block too.

.. code:: cpp

   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      float val = __expf(logits[i] - qk_max);
      logits[i] = val;
      exp_sum += val;
   }
   ...
   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

- First, each thread sums the exp values over its assigned tokens and,
  meanwhile, converts each entry of ``logits`` from ``qk`` to
  ``exp(qk - qk_max)``. Please note that the ``qk_max`` here is already
  the max ``qk`` across the whole thread block. Then we can reduce
  ``exp_sum`` across the whole thread block just like ``qk_max``.

.. code:: cpp

   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
      logits[i] *= inv_sum;
   }

- Finally, with the reduced ``qk_max`` and ``exp_sum``, we can obtain
  the final normalized softmax result as ``logits``. This ``logits``
  variable will be used for dot multiplication with the value data in
  later steps. Now, it should store the normalized softmax result of
  ``qk`` for all assigned context tokens.

Value
-----

.. figure:: ../../assets/kernel/value.png
   :alt: value
   :width: 70%
   :align: center

   Value data of all context tokens at one head

.. figure:: ../../assets/kernel/logits_vec.png
   :alt: logits_vec
   :width: 50%
   :align: center

   ``logits_vec`` for one thread

.. figure:: ../../assets/kernel/v_vec.png
   :alt: v_vec
   :width: 70%
   :align: center

   List of ``v_vec`` for one thread

- Now we need to retrieve the value data and perform dot multiplication
  with ``logits``. Unlike query and key, there is no thread group
  concept for value data. As shown in the diagram, unlike the key token
  memory layout, elements from the same column correspond to the same
  value token. For one block of value data, there are ``HEAD_SIZE``
  rows and ``BLOCK_SIZE`` columns that are split into multiple
  ``v_vecs``.
- Each thread always fetches ``V_VEC_SIZE`` elements from the same
  ``V_VEC_SIZE`` tokens at a time. As a result, a single thread
  retrieves multiple ``v_vec``\ s from different rows and the same
  columns through multiple inner iterations. For each ``v_vec``, it
  needs to be dot multiplied with the corresponding ``logits_vec``,
  which is also ``V_VEC_SIZE`` elements from ``logits``. Overall, with
  multiple inner iterations, each warp will process one block of value
  tokens, and with multiple outer iterations, the whole context's value
  tokens are processed.

.. code:: cpp

   float accs[NUM_ROWS_PER_THREAD];
   for ... { // Iteration over different blocks.
       logits_vec = ...
       for ... { // Iteration over different rows.
           v_vec = ...
           ...
           accs[i] += dot(logits_vec, v_vec);
       }
   }

- As shown in the above pseudo code, in the outer loop, similar to
  ``k_ptr``, ``logits_vec`` iterates over different blocks and reads
  ``V_VEC_SIZE`` elements from ``logits``. In the inner loop, each
  thread reads ``V_VEC_SIZE`` elements from the same tokens as a
  ``v_vec`` and performs dot multiplication. It is important to note
  that in each inner iteration, the thread fetches elements at
  different head positions for the same tokens. The dot result is then
  accumulated in ``accs``. Therefore, each entry of ``accs`` is mapped
  to a head position assigned to the current thread.
- For example, if ``BLOCK_SIZE`` is 16 and ``V_VEC_SIZE`` is 8, each
  thread fetches 8 value elements for 8 tokens at a time. Each element
  is from a different token at the same head position. If ``HEAD_SIZE``
  is 128 and ``WARP_SIZE`` is 32, for each inner loop, a warp needs to
  fetch ``WARP_SIZE * V_VEC_SIZE = 256`` elements. This means there are
  a total of 128 \* 16 / 256 = 8 inner iterations for a warp to handle
  a whole block of value tokens. And the ``accs`` in each thread
  contains 8 elements that are accumulated at 8 different head
  positions. For thread 0, the ``accs`` variable will have 8 elements,
  which are the 0th, 32nd, …, 224th elements of a value head,
  accumulated from all 8 assigned tokens.

LV
---

- Now, we need to perform reduction for ``accs`` within each warp. This
  process allows each thread to accumulate the ``accs`` for the
  assigned head positions of all tokens in one block.

.. code:: cpp

   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      float acc = accs[i];
      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
         acc += VLLM_SHFL_XOR_SYNC(acc, mask);
      }
      accs[i] = acc;
   }

- Next, we perform reduction for ``accs`` across all warps, allowing
  each thread to have the accumulation of ``accs`` for the assigned
  head positions of all context tokens. Please note that each ``accs``
  in every thread only stores the accumulation for a portion of the
  elements of the entire head for all context tokens. However, overall,
  all results for the output have been calculated; they are just stored
  in different threads' register memory.

.. code:: cpp

   float* out_smem = reinterpret_cast<float*>(shared_mem);
   for (int i = NUM_WARPS; i > 1; i /= 2) {
       // Upper warps write to shared memory.
       ...
       float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
           ...
           dst[row_idx] = accs[i];
       }

       // Lower warps update the output.
       const float* src = &out_smem[warp_idx * HEAD_SIZE];
       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
           ...
           accs[i] += src[row_idx];
       }

       // Write out the accs.
   }

Output
------

- Now we can write all of the calculated result from local register
  memory to the final output global memory.

.. code:: cpp

   scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                           + head_idx * max_num_partitions * HEAD_SIZE
                           + partition_idx * HEAD_SIZE;

- First, we need to define the ``out_ptr`` variable, which points to
  the start address of the assigned sequence and the assigned head.

.. code:: cpp

   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
         from_float(*(out_ptr + row_idx), accs[i]);
      }
   }

- Finally, we need to iterate over different assigned head positions
  and write out the corresponding accumulated result based on
  ``out_ptr``.

docs/source/design/multimodal/adding_multimodal_plugin.md (new file)
@ -0,0 +1,16 @@
(adding-multimodal-plugin)=

# Adding a Multimodal Plugin

This document teaches you how to add a new modality to vLLM.

Each modality in vLLM is represented by a {class}`~vllm.multimodal.MultiModalPlugin` and registered to {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.
For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to {meth}`~vllm.multimodal.MultiModalRegistry.register_plugin`.

The remainder of this document details how to define custom {class}`~vllm.multimodal.MultiModalPlugin` s.

```{note}
This article is a work in progress.
```

% TODO: Add more instructions on how to add new plugins once embeddings is in.
@ -1,17 +0,0 @@
.. _adding_multimodal_plugin:

Adding a Multimodal Plugin
==========================

This document teaches you how to add a new modality to vLLM.

Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`.

The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin` s.

.. note::
    This article is a work in progress.

..
    TODO: Add more instructions on how to add new plugins once embeddings is in.
@ -1,66 +1,83 @@
.. _multi_modality:
(multi-modality)=

Multi-Modality
==============
# Multi-Modality

```{eval-rst}
.. currentmodule:: vllm.multimodal

vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
```

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_mm_models>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
vLLM provides experimental support for multi-modal models through the {mod}`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models)
via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
by following [this guide](#adding-multimodal-plugin).

Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
Looking to add your own multi-modal model? Please follow the instructions listed [here](#enabling-multimodal-inputs).

Guides
++++++
## Guides

.. toctree::
   :maxdepth: 1
```{toctree}
:maxdepth: 1

adding_multimodal_plugin
adding_multimodal_plugin
```

Module Contents
+++++++++++++++
## Module Contents

```{eval-rst}
.. automodule:: vllm.multimodal
```

Registry
--------
### Registry

```{eval-rst}
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY
```

```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalRegistry
    :members:
    :show-inheritance:
```

Base Classes
------------
### Base Classes

```{eval-rst}
.. autodata:: vllm.multimodal.NestedTensors
```

```{eval-rst}
.. autodata:: vllm.multimodal.BatchedTensorInputs
```

```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
    :members:
    :show-inheritance:
```

```{eval-rst}
.. autodata:: vllm.multimodal.MultiModalDataDict
```

```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalKwargs
    :members:
    :show-inheritance:
```

```{eval-rst}
.. autoclass:: vllm.multimodal.MultiModalPlugin
    :members:
    :show-inheritance:
```

Image Classes
-------------
### Image Classes

```{eval-rst}
.. automodule:: vllm.multimodal.image
    :members:
    :show-inheritance:
```
docs/source/design/plugin_system.md (new file)
@ -0,0 +1,54 @@
(plugin-system)=

# vLLM's Plugin System

The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.

## How Plugins Work in vLLM

Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [](#arch-overview)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work.

## How vLLM Discovers Plugins

vLLM's plugin system uses the standard Python `entry_points` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:

```python
# inside `setup.py` file
from setuptools import setup

setup(name='vllm_add_dummy_model',
      version='0.1',
      packages=['vllm_add_dummy_model'],
      entry_points={
          'vllm.general_plugins':
          ["register_dummy_model = vllm_add_dummy_model:register"]
      })

# inside `vllm_add_dummy_model.py` file
def register():
    from vllm import ModelRegistry

    if "MyLlava" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MyLlava",
                                     "vllm_add_dummy_model.my_llava:MyLlava")
```

For more information on adding entry points to your package, please check the [official documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html).

Every plugin has three parts:

1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
2. **Plugin name**: The name of the plugin. This is the value in the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.

## What Can Plugins Do?

Currently, the primary use case for plugins is to register custom, out-of-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

## Guidelines for Writing Plugins

- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.

## Compatibility Guarantee

vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
@ -1,62 +0,0 @@
.. _plugin_system:

vLLM's Plugin System
====================

The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.

How Plugins Work in vLLM
------------------------

Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins <https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16>`__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work.

How vLLM Discovers Plugins
--------------------------

vLLM's plugin system uses the standard Python ``entry_points`` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:

.. code-block:: python

    # inside `setup.py` file
    from setuptools import setup

    setup(name='vllm_add_dummy_model',
          version='0.1',
          packages=['vllm_add_dummy_model'],
          entry_points={
              'vllm.general_plugins':
              ["register_dummy_model = vllm_add_dummy_model:register"]
          })

    # inside `vllm_add_dummy_model.py` file
    def register():
        from vllm import ModelRegistry

        if "MyLlava" not in ModelRegistry.get_supported_archs():
            ModelRegistry.register_model("MyLlava",
                                         "vllm_add_dummy_model.my_llava:MyLlava")

For more information on adding entry points to your package, please check the `official documentation <https://setuptools.pypa.io/en/latest/userguide/entry_point.html>`__.

Every plugin has three parts:

1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group ``vllm.general_plugins`` to register general plugins. This is the key of ``entry_points`` in the ``setup.py`` file. Always use ``vllm.general_plugins`` for vLLM's general plugins.

2. **Plugin name**: The name of the plugin. This is the value in the ``entry_points`` dictionary. In the example above, the plugin name is ``register_dummy_model``. Plugins can be filtered by their names using the ``VLLM_PLUGINS`` environment variable. To load only a specific plugin, set ``VLLM_PLUGINS`` to the plugin name.

3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is ``vllm_add_dummy_model:register``, which refers to a function named ``register`` in the ``vllm_add_dummy_model`` module.

What Can Plugins Do?
--------------------

Currently, the primary use case for plugins is to register custom, out-of-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.

Guidelines for Writing Plugins
------------------------------

- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.

Compatibility Guarantee
-----------------------

vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.