Compare commits

..

7 Commits

8 changed files with 73 additions and 47 deletions

View File

@@ -3,31 +3,15 @@
Installation
============
vLLM is a Python library that also contains some C++ and CUDA code.
This additional code requires compilation on the user's machine.
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
Requirements
------------
* OS: Linux
* Python: 3.8 or higher
* CUDA: 11.0 -- 11.8
* Python: 3.8 -- 3.11
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
.. note::
As of now, vLLM does not support CUDA 12.
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
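To check which CUDA version your existing PyTorch build was compiled against, you can query it from Python. This is an illustrative snippet (assuming PyTorch is already installed), not part of the official steps:
.. code-block:: python
import torch
print(torch.version.cuda)  # e.g. "11.8" for a CUDA 11.8 build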
.. tip::
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
Install with pip
----------------
@@ -40,7 +24,7 @@ You can install vLLM using pip:
$ conda activate myenv
$ # Install vLLM.
$ pip install vllm # This may take 5-10 minutes.
$ pip install vllm
.. _build_from_source:
@@ -55,3 +39,11 @@ You can also build and install vLLM from source:
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ pip install -e . # This may take 5-10 minutes.
.. tip::
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
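After installation by either method, a quick sanity check is to import the package and print its version. This is a hedged example rather than part of the documented instructions:
.. code-block:: python
import vllm
print(vllm.__version__)  # "0.1.7" for this release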

View File

@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -54,7 +54,8 @@ if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
)
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
@@ -78,7 +80,9 @@ if not compute_capabilities:
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
NVCC_FLAGS += [
"-gencode", f"arch=compute_{capability},code=sm_{capability}"
]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
@@ -91,7 +95,10 @@ ext_modules = []
cache_extension = CUDAExtension(
name="vllm.cache_ops",
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(cache_extension)
@@ -99,7 +106,10 @@ ext_modules.append(cache_extension)
attention_extension = CUDAExtension(
name="vllm.attention_ops",
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(attention_extension)
@@ -107,7 +117,10 @@ ext_modules.append(attention_extension)
positional_encoding_extension = CUDAExtension(
name="vllm.pos_encoding_ops",
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(positional_encoding_extension)
@@ -115,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
layernorm_extension = CUDAExtension(
name="vllm.layernorm_ops",
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(layernorm_extension)
@@ -123,7 +139,10 @@ ext_modules.append(layernorm_extension)
activation_extension = CUDAExtension(
name="vllm.activation_ops",
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(activation_extension)
@@ -138,8 +157,8 @@ def find_version(filepath: str):
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
with open(filepath) as fp:
version_match = re.search(
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
fp.read(), re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")
@@ -162,7 +181,8 @@ setuptools.setup(
version=find_version(get_path("vllm", "__init__.py")),
author="vLLM Team",
license="Apache 2.0",
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
description=("A high-throughput and memory-efficient inference and "
"serving engine for LLMs"),
long_description=read_readme(),
long_description_content_type="text/markdown",
url="https://github.com/vllm-project/vllm",
@@ -174,11 +194,12 @@ setuptools.setup(
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
"examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,

View File

@@ -133,9 +133,10 @@ def test_rotary_embedding(
device="cuda")
# Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
inv_freq = 1.0 / (base**(
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
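For context, the cached cos/sin values are what a rotary embedding ultimately multiplies against query/key pairs. Below is a minimal sketch of one common (NeoX-style, half-split) way to apply them, using a hypothetical helper name and plain PyTorch rather than vLLM's CUDA kernel:
import torch
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rotary_dim]; cos/sin: [..., rotary_dim // 2] taken from the cache
    # at the token's position.
    x1, x2 = x.chunk(2, dim=-1)                      # split the rotary dims in half
    return torch.cat((x1 * cos - x2 * sin,           # rotate each (x1, x2) pair by the
                      x2 * cos + x1 * sin), dim=-1)  # position-dependent angle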

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
__version__ = "0.1.6"
__version__ = "0.1.7"
__all__ = [
"LLM",

View File

@@ -114,8 +114,9 @@ class ModelConfig:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type == "falcon"
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):

View File

@@ -153,7 +153,7 @@ class LLMEngine:
placement_group=placement_group,
placement_group_capture_child_tasks=True),
**ray_remote_kwargs,
)(RayWorker).remote()
)(RayWorker).remote(self.model_config.trust_remote_code)
self.workers.append(worker)
# Initialize torch distributed process group for the workers.

View File

@@ -11,7 +11,11 @@ try:
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self) -> None:
def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
# pylint: disable=import-outside-toplevel
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
def init_worker(self, worker_init_fn):

View File

@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
def set_attn_bias(
self,
input_metadata: InputMetadata,
dtype: torch.dtype,
) -> None:
del dtype # Unused.
if input_metadata.attn_bias:
# Already set by a previous layer.
return
@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.set_attn_bias(input_metadata, dtype=query.dtype)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
@@ -259,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
self.is_neox_style = is_neox_style
# Create the cos and sin cache.
inv_freq = 1.0 / (base**(
torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
t = torch.arange(max_position, device="cuda").float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
@@ -340,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
def set_attn_bias(self, input_metadata: InputMetadata,
dtype: torch.dtype) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
bias = torch.arange(prompt_len, dtype=dtype)
# Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
prompt_len,
padded_len,
device=self.alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
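For reference, a minimal sketch of the per-prompt ALiBi bias construction described above, with illustrative slope values and shapes rather than vLLM's exact tensors:
import torch
slopes = torch.tensor([0.5, 0.25])                 # example: one slope per attention head
prompt_len = 4
bias = torch.arange(prompt_len, dtype=torch.float)
bias = bias[None, :] - bias[:, None]               # [prompt_len, prompt_len] relative positions (j - i)
bias = bias[None, :, :] * slopes[:, None, None]    # [num_heads, prompt_len, prompt_len], scaled per head
# A lower-triangular causal mask is then applied on top of this bias during attention.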