Compare commits

...

15 Commits

SHA1 Message Date
31c1f3255e Bump up to v0.2.5 (#2095) 2023-12-13 23:56:15 -08:00
21d93c140d Optimize Mixtral with expert parallelism (#2090) 2023-12-13 23:55:07 -08:00
f1c8520146 [BugFix] Fix input positions for long context with sliding window (#2088) 2023-12-13 12:28:13 -08:00
096827c284 [Docs] Add notes on ROCm-supported models (#2087) 2023-12-13 09:45:34 -08:00
6565d9e33e Update installation instruction for vLLM + CUDA 11.8 (#2086) 2023-12-13 09:25:59 -08:00
f375ec8440 [ROCm] Upgrade xformers version for ROCm & update doc (#2079)
Co-authored-by: miloice <jeffaw99@hotmail.com>
2023-12-13 00:56:05 -08:00
518369d78c Implement lazy model loader (#2044) 2023-12-12 22:21:45 -08:00
30bad5c492 Fix peak memory profiling (#2031) 2023-12-12 22:01:53 -08:00
3fefe271ec Update Dockerfile to build Megablocks (#2042) 2023-12-12 17:34:17 -08:00
6428f1d051 Support MPT with GQA (#1938)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-12-12 10:16:05 -08:00
7e1b21daac Remove einops from requirements (#2049) 2023-12-12 09:34:09 -08:00
cb3f30c600 Upgrade transformers version to 4.36.0 (#2046) 2023-12-11 18:39:14 -08:00
f3e024bece [CI/CD] Upgrade PyTorch version to v2.1.1 (#2045) 2023-12-11 17:48:11 -08:00
31d2ab4aff Remove python 3.10 requirement (#2040) 2023-12-11 12:26:42 -08:00
eb17212858 Update Dockerfile to support Mixtral (#2027) 2023-12-11 11:59:08 -08:00
27 changed files with 514 additions and 521 deletions

View File

@@ -49,7 +49,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.1.0']
pytorch-version: ['2.1.1']
cuda-version: ['11.8', '12.1']
steps:

View File

@@ -75,7 +75,7 @@ ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate fschat
pip install accelerate
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

View File

@@ -47,12 +47,12 @@ RUN mkdir libs \
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.22.post7 --no-deps
RUN pip install xformers==0.0.23 --no-deps
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& bash patch_xformers-0.0.22.post7.rocm.sh \
&& bash patch_xformers-0.0.23.rocm.sh \
&& python3 setup.py install \
&& cd ..

View File

@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
```bash
pip install vllm
```
**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks) on **Python 3.10**:
```bash
pip install megablocks
```
## Getting Started

View File

@@ -3,7 +3,7 @@
Installation with ROCm
======================
vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm.
vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
Data types currently supported in ROCm are FP16 and BF16.
@@ -29,7 +29,7 @@ Installation options:
.. code-block:: console
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
$ docker run -it \
--network=host \
--group-add=video \
@@ -70,12 +70,12 @@ You can build and install vLLM from source:
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
- You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
.. code-block:: console
$ pip install xformers==0.0.22.post7 --no-deps
$ bash patch_xformers-0.0.22.post7.rocm.sh
$ pip install xformers==0.0.23 --no-deps
$ bash patch_xformers.rocm.sh
3. Build vLLM.
@@ -127,12 +127,12 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
- You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
.. code-block:: console
$ pip install xformers==0.0.22.post7 --no-deps
$ bash patch_xformers-0.0.22.post7.rocm.sh
$ pip install xformers==0.0.23 --no-deps
$ bash patch_xformers.rocm.sh
3. Build vLLM.

View File

@@ -20,7 +20,7 @@ You can install vLLM using pip:
.. code-block:: console
$ # (Optional) Create a new conda environment.
$ conda create -n myenv python=3.8 -y
$ conda create -n myenv python=3.9 -y
$ conda activate myenv
$ # Install vLLM with CUDA 12.1.
@@ -34,8 +34,9 @@ You can install vLLM using pip:
.. code-block:: console
$ # Install vLLM with CUDA 11.8.
$ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`).
$ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl
$ export VLLM_VERSION=0.2.4
$ export PYTHON_VERSION=39
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
$ # Re-install PyTorch with CUDA 11.8.
$ pip uninstall torch -y

View File

@@ -73,6 +73,9 @@ If your model uses one of the above model architectures, you can seamlessly run
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
.. note::
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
.. tip::
The easiest way to check if your model is supported is to run the program below:
@@ -84,12 +87,17 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
output = llm.generate("Hello, my name is")
print(output)
To use model from www.modelscope.cn
If vLLM successfully generates text, it indicates that your model is supported.
.. tip::
To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
.. code-block:: shell
$ export VLLM_USE_MODELSCOPE=True
And use with :code:`trust_remote_code=True`.
.. code-block:: python
from vllm import LLM
@@ -97,5 +105,3 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
If vLLM successfully generates text, it indicates that your model is supported.

View File

@@ -1,21 +1,32 @@
#!/bin/bash
set -e
XFORMERS_VERSION="0.0.23"
export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
exit 1
fi
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
echo $XFORMERS_FMHA_FLASH_PATH
echo $XFORMERS_FMHA_COMMON_PATH
echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
else
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
fi
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
else
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"

View File

@@ -4,7 +4,7 @@ requires = [
"ninja",
"packaging",
"setuptools >= 49.4.0",
"torch >= 2.1.0",
"torch >= 2.1.1",
"wheel",
]
build-backend = "setuptools.build_meta"

View File

@@ -8,9 +8,7 @@ pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
huggingface_hub<0.18,>=0.16.4
einops # Required for phi-1_5
transformers >= 4.34.0 # Required for Mistral.
transformers >= 4.36.0 # Required for Mixtral.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.

View File

@@ -5,10 +5,9 @@ pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
einops # Required for phi-1_5
torch >= 2.1.0
transformers >= 4.34.0 # Required for Mistral.
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
torch >= 2.1.1
transformers >= 4.36.0 # Required for Mixtral.
xformers >= 0.0.23 # Required for CUDA 12.1.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.

View File

@@ -1,6 +1,6 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
@@ -36,44 +36,44 @@
FLASH_VERSION = "0.0.0"
try:
@@ -15,9 +15,12 @@
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
- if flash_ver_parsed < (2, 3):
- raise ImportError("Requires 2.3 for sliding window support")
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
- if (
- flash_ver_parsed != (2, 3, 6)
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
- ):
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
@@ -29,35 +32,41 @@
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+ # if flash_ver_parsed < (2, 3):
+ # raise ImportError("Requires 2.3 for sliding window support")
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+ # if (
+ # flash_ver_parsed != (2, 3, 6)
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+ # ):
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- "bool is_causal, int window_left, "
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
-
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- "float p, float softmax_scale, bool is_causal, "
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ # "bool is_causal, int window_left, "
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
@@ -65,52 +74,61 @@
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ # "float p, float softmax_scale, bool is_causal, "
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -98,8 +98,8 @@
@@ -111,8 +111,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
- window_left, # window_size_left
- window_right, # window_size_right
+ # window_left, # window_size_left
+ # window_right, # window_size_right
return_softmax,
None, # rng
)
@@ -127,8 +127,8 @@
@@ -134,15 +134,15 @@
out,
cu_seq_lens_q,
cu_seq_lens_k,
- seqused_k,
+ # seqused_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
- window_left,
- window_right,
+ # window_left,
+ # window_right,
return_softmax,
None,
)
@@ -169,8 +169,8 @@
@@ -184,8 +184,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
@@ -193,15 +193,15 @@
@@ -208,15 +208,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
- window_left,
- window_right,
+ # window_left,
+ # window_right,
None,
rng_state,
)
@@ -123,7 +141,7 @@
except ImportError:
pass
@@ -348,7 +348,7 @@
@@ -400,7 +400,7 @@
implementation.
"""

View File

@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple
import pytest
@@ -7,21 +8,32 @@ from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
_TEST_PROMPTS = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
"Describe the basic components of a neural network and how it can be trained.",
"Write a short story about a robot that dreams for the first time.",
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]
_TEST_PROMPTS = ["prompts/example.txt"]
_LONG_PROMPTS = ["prompts/summary.txt"]
def _read_prompts(filename: str) -> str:
prompts = []
with open(filename, "r") as f:
prompt = f.readline()
prompts.append(prompt)
return prompts
@pytest.fixture
def example_prompts() -> List[str]:
return _TEST_PROMPTS
prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
return prompts
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
for filename in _LONG_PROMPTS:
prompts += _read_prompts(os.path.join("tests", filename))
return prompts
_STR_DTYPE_TO_TORCH_DTYPE = {

View File

@@ -0,0 +1,37 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py --forked`.
"""
import pytest
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
hf_runner,
vllm_runner,
example_long_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
del vllm_model
for i in range(len(example_long_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@@ -0,0 +1,8 @@
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
Describe the basic components of a neural network and how it can be trained.
Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'

File diff suppressed because one or more lines are too long

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
__version__ = "0.2.4"
__version__ = "0.2.5"
__all__ = [
"LLM",

View File

@@ -120,14 +120,16 @@ class ModelConfig:
if load_format == "auto":
load_format = "pt"
# FIXME(woosuk): This is a temporary hack. Support safetensor weights.
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", [])
if "MixtralForCausalLM" in architectures and load_format != "pt":
logger.info(
"Currently, only 'pt' format is supported for Mixtral. "
"Changing the format to 'pt'. This may re-download the "
"weights if you have downloaded the safetensor weights.")
load_format = "pt"
if "MixtralForCausalLM" in architectures:
if load_format == "pt":
raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. ")
elif load_format == "auto":
# Do not fall back to pt weights.
load_format = "safetensors"
self.load_format = load_format
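
The Mixtral branch above resolves the weight format as follows: 'pt' is rejected outright and 'auto' resolves to safetensors. A standalone sketch of that rule (the helper name and structure are illustrative, not vLLM API):

```python
# Illustrative only: mirrors the intent of the Mixtral branch in ModelConfig above.
def resolve_mixtral_load_format(load_format: str) -> str:
    if load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead.")
    if load_format == "auto":
        # Do not fall back to pt weights.
        return "safetensors"
    return load_format

print(resolve_mixtral_load_format("auto"))         # safetensors
print(resolve_mixtral_load_format("safetensors"))  # safetensors
```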

View File

@@ -138,7 +138,8 @@ class PagedAttention(nn.Module):
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, batch_size, seq_len, query.dtype)
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
# TODO(woosuk): Too many view operations. Let's try to reduce them
# in the future for code readability.
@@ -180,31 +181,34 @@
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
bias = torch.arange(seq_len, dtype=dtype, device="cuda")
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
alibi_slopes.shape[0],
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
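
A toy illustration of the padding convention noted in the comment above: xformers requires the custom bias to be sliced from a tensor whose last dimension is a multiple of 8 (the sizes below are made up):

```python
import torch

seq_len = 13
padded_len = (seq_len + 7) // 8 * 8  # 16
num_heads = 4
# Allocate the padded buffer, then slice back to seq_len so the bias is a
# view into storage whose row stride is a multiple of 8.
bias = torch.empty(1, num_heads, seq_len, padded_len)[:, :, :, :seq_len]
print(bias.shape)  # torch.Size([1, 4, 13, 13])
```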

View File

@@ -7,54 +7,9 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
from vllm.utils import is_hip
from vllm.logger import init_logger
logger = init_logger(__name__)
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"AquilaModel": AquilaForCausalLM,
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
# transformers's mpt class has lower case
"MptForCausalLM": MPTForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
"YiForCausalLM": YiForCausalLM,
}
# Models to be disabled in ROCm
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
for rocm_model in _ROCM_UNSUPPORTED_MODELS:
del _MODEL_REGISTRY[rocm_model]
# Models partially supported in ROCm
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"MistralForCausalLM":
"Sliding window attention is not supported in ROCm's flash attention",
}
@contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _MODEL_REGISTRY:
if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
f"{arch} is not fully supported in ROCm. Reason: "
f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
return _MODEL_REGISTRY[arch]
elif arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {arch} is not supported by ROCm for now. \n"
f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module:

View File

@@ -1,41 +1,82 @@
from vllm.model_executor.models.aquila import AquilaForCausalLM
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
BaichuanForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM
from vllm.model_executor.models.mixtral import MixtralForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
from vllm.model_executor.models.yi import YiForCausalLM
import importlib
from typing import List, Optional, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
# Architecture -> (module, class).
_MODELS = {
"AquilaModel": ("aquila", "AquilaForCausalLM"),
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM"),
}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
}
class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
__all__ = [
"AquilaForCausalLM",
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"ChatGLMForCausalLM",
"FalconForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",
"GPTNeoXForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"OPTForCausalLM",
"PhiForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",
"MixtralForCausalLM",
"YiForCausalLM",
"ModelRegistry",
]
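
A brief usage sketch of the lazy registry defined above (assumes vLLM at this revision is importable); the model module is only imported when its class is requested:

```python
from vllm.model_executor.models import ModelRegistry

print(ModelRegistry.get_supported_archs())                    # all registered architectures
llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")  # imports the llama module on demand
print(llama_cls)
print(ModelRegistry.load_model_cls("NotARealArchitecture"))   # None for unknown architectures
```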

View File

@@ -29,25 +29,13 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import MistralConfig
try:
import megablocks.ops as ops
except ImportError:
print(
"MegaBlocks not found. Please install it by `pip install megablocks`. "
"Note that MegaBlocks depends on mosaicml-turbo, which only supports "
"Python 3.10 for now.")
try:
import stk
except ImportError:
print(
"STK not found: please see https://github.com/stanford-futuredata/stk")
from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,8 +55,134 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class DummyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.w1 = nn.Linear(0, 0, bias=False)
self.w2 = nn.Linear(0, 0, bias=False)
self.w3 = nn.Linear(0, 0, bias=False)
set_weight_attrs(self.w1.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w2.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w3.weight,
{"weight_loader": self.dummy_weight_loader})
def forward(self, *args, **kwargs) -> None:
raise NotImplementedError()
def dummy_weight_loader(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
# Noop
return
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else DummyModule()
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
linear_method=linear_method)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
class MixtralAttention(nn.Module):
@@ -79,6 +193,7 @@ class MixtralAttention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -103,24 +218,26 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.wqkv = QKVParallelLinear(
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.wo = RowParallelLinear(
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=False, # weights not in HF format
is_neox_style=True,
)
self.attn = PagedAttention(
self.num_heads,
@@ -138,334 +255,93 @@ class MixtralAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.wo(attn_output)
output, _ = self.o_proj(attn_output)
return output
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int,
top_k: int):
super().__init__()
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_experts = num_experts
self.top_k = top_k
# gating
self.gate = nn.Linear(self.hidden_dim,
self.num_experts,
bias=False,
device=torch.cuda.current_device())
tp_size = get_tensor_model_parallel_world_size()
assert self.ffn_dim % tp_size == 0
self.ffn_dim_per_partition = self.ffn_dim // tp_size
# merged expert weights, all of size (ffn_dim * n_experts, model_dim)
self.w1 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w1, {"weight_loader": self.moe_weight_loader})
self.w2 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w2, {"weight_loader": self.moe_weight_loader})
self.w3 = nn.Parameter(
torch.empty(self.ffn_dim_per_partition * self.num_experts,
self.hidden_dim,
device=torch.cuda.current_device()))
set_weight_attrs(self.w3, {"weight_loader": self.moe_weight_loader})
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
# Calculate the number of bits needed to represent the column indices
# in the intermediate sparse matrix.
max_column_index = (self.ffn_dim * self.num_experts) // self.blocking
self.transpose_sort_end_bit = max(
int(np.ceil(np.log2(max_column_index))), 1)
def moe_weight_loader(self, param: nn.Parameter,
loaded_weight: torch.Tensor) -> None:
"""
Load the weights for the MoE linear layer.
"""
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.ffn_dim_per_partition
loaded_weight = loaded_weight.view(self.num_experts, self.ffn_dim, -1)
loaded_weight = loaded_weight[:, shard_size * tp_rank:shard_size *
(tp_rank + 1)]
loaded_weight = loaded_weight.reshape_as(param)
param.data.copy_(loaded_weight)
def sparse_transpose(
self, size: int, row_indices,
column_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
block_columns = size[1] // self.blocking
# Sort row indices by column indices to get the transposed matrix's
# column indices.
#
# NOTE: Our sort operation uses the same width indices as the input
# values. To avoid overflow when we have large activation matrices
# we cast to 32-bit before sorting.
_, gather_indices = ops.sort(column_indices.int(),
self.transpose_sort_end_bit)
# There are a constant number of blocks in every row of the sparse
# matrix. A blocks offset is:
#
# row_index * blocks_per_row + column_index % blocks_per_row
#
# Once we have the block offsets ordered for transposition we can
# divide by blocks_per_row to get the transposed column indices.
column_indices_t = row_indices.gather(0, gather_indices.long())
block_offsets_t = gather_indices.int()
zero = torch.zeros((1, ), dtype=torch.int32, device=row_indices.device)
nnz_per_column = ops.histogram(column_indices, block_columns)
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
offsets_t = torch.cat([zero, nnz_per_column])
return column_indices_t, offsets_t, block_offsets_t
def topology(self, x: torch.Tensor,
padded_bins: torch.Tensor) -> "stk.Matrix":
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim_per_partition % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim_per_partition // self.blocking
offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim_per_partition * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
shape, row_indices, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
column_indices_t,
offsets_t,
block_offsets_t,
)
def indices_and_padded_bins(
self, selected_experts: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
selected_experts = selected_experts.int()
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# Round the token counts up to the block size used in
# the matrix muliplications. Caculate the starting
# position of each bin.
padded_tokens_per_expert = ops.round_up(tokens_per_expert,
self.blocking)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
return indices, bin_ids, bins, padded_bins, tokens_per_expert
@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = F.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.flatten().to(x.dtype)
selected_experts = selected_experts.flatten()
(indices, bin_ids, bins, padded_bins,
_) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
F.silu(stk.ops.sdd(x, self.w1.t(), topo).data) *
stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
x = tensor_model_parallel_all_reduce(x)
# Permute back and remove padding
# (top_k * sequence_length, model_dim)
x = ops.padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.top_k,
self.quantize_scatter_num_bits,
)
return x.view(*input_shape)
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MistralConfig,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.attention = MixtralAttention(
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window)
self.block_sparse_moe = BlockSparseMoE(
hidden_dim=self.hidden_size,
ffn_dim=config.intermediate_size,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
sliding_window=config.sliding_window,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
x: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> torch.Tensor:
r = self.attention(
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=self.attention_norm(x),
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
h = x + r
r = self.block_sparse_moe(self.ffn_norm(h))
out = h + r
return out
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralForCausalLM(nn.Module):
class MixtralModel(nn.Module):
def __init__(
self,
config: MistralConfig,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
assert linear_method is None
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tok_embeddings = VocabParallelEmbedding(
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config)
MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
@@ -475,20 +351,42 @@ class MixtralForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.tok_embeddings(input_ids)
# forward
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
cache_event, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
@@ -496,7 +394,7 @@ class MixtralForCausalLM(nn.Module):
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.output.weight, hidden_states,
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
@@ -507,10 +405,11 @@ class MixtralForCausalLM(nn.Module):
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("wqkv", "wq", "q"),
("wqkv", "wk", "k"),
("wqkv", "wv", "v"),
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
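
The expert-parallel MixtralMoE.forward above routes each token to its top-k experts, then masks out experts owned by other tensor-parallel ranks before the all-reduce. A toy sketch of that routing math with made-up sizes:

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2          # illustrative sizes
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the top-k

# On a given rank, only its local experts contribute; the mask zeroes the
# weights of tokens that were not routed to this expert.
expert_idx = 3
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
print(expert_weights.shape)  # torch.Size([4, 1])
```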

View File

@@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.head_dim = self.d_model // self.total_num_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if "kv_n_heads" in config.attn_config:
self.total_num_kv_heads = config.attn_config['kv_n_heads']
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
@@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
linear_method=linear_method,
)
@@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
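
A standalone sketch of the KV-head partitioning rule described in the comments above (the helper is illustrative, not part of vLLM):

```python
def kv_heads_per_rank(total_num_kv_heads: int, tp_world_size: int) -> int:
    if total_num_kv_heads >= tp_world_size:
        # Enough KV heads: partition them across tensor-parallel ranks.
        assert total_num_kv_heads % tp_world_size == 0
    else:
        # Fewer KV heads than ranks: replicate them instead.
        assert tp_world_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_world_size)

print(kv_heads_per_rank(8, 4))  # 2 (partitioned)
print(kv_heads_per_rank(2, 4))  # 1 (replicated)
```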

View File

@@ -40,11 +40,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
return int(max_shared_mem)
def get_gpu_memory(gpu: int = 0) -> int:
"""Returns the total memory of the GPU in bytes."""
return torch.cuda.get_device_properties(gpu).total_memory
def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes."""
return psutil.virtual_memory().total

View File

@@ -134,14 +134,14 @@ class ModelRunner:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
context_len = seq_data.get_len()
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
context_lens.append(context_len)
position = context_len - 1
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
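
A toy illustration of the fix above: the token position comes from the full sequence length, and only the context length is clamped by the sliding window (numbers are made up):

```python
sliding_window = 4096
seq_len = 5000                                      # sequence longer than the window

position = seq_len - 1                              # 4999: used for the token position
context_len = (seq_len if sliding_window is None
               else min(seq_len, sliding_window))   # 4096: used for attention context
print(position, context_len)
```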

View File

@@ -13,7 +13,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.utils import get_gpu_memory
class Worker:
@@ -81,7 +80,6 @@ class Worker:
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
@@ -90,8 +88,9 @@ class Worker:
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated()
total_gpu_memory = get_gpu_memory()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config)
num_gpu_blocks = int(
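
A minimal sketch of the new profiling approach above, which infers usage from the device's free/total memory so that allocations made outside PyTorch's caching allocator are also counted (requires a CUDA device):

```python
import torch

if torch.cuda.is_available():
    torch.cuda.synchronize()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    peak_memory = total_gpu_memory - free_gpu_memory
    print(f"used GPU memory: {peak_memory / 1024**3:.2f} GiB")
```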