[VLM] Calculate maximum number of multi-modal tokens by model (#6121)

This commit is contained in:
Cyrus Leung
2024-07-05 07:37:23 +08:00
committed by GitHub
parent 69ec3ca14c
commit ae96ef8fbd
12 changed files with 265 additions and 95 deletions

View File

@ -51,17 +51,16 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
2. Register input mappers
-------------------------
For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+ @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
@ -69,22 +68,22 @@ A default mapper is available for each modality in the core vLLM library. This i
:ref:`input_processing_pipeline`
3. (Optional) Register dummy data
---------------------------------
3. Register maximum number of multimodal tokens
----------------------------------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
For each modality type that the model accepts as input, calculate the maximum possible number of tokens
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
Here are some examples:
@ -95,7 +94,36 @@ Here are some examples:
:ref:`input_processing_pipeline`
4. (Optional) Register input processor
4. (Optional) Register dummy data
---------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
Here are some examples:
- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
5. (Optional) Register input processor
--------------------------------------
Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:

View File

@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
every model to perform profiling with.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
every model to perform profiling with.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.
To consume the server, you can use the OpenAI client like in the example below: