Separate attention backends (#3005)

This commit is contained in:
Woosuk Kwon
2024-03-07 01:45:50 -08:00
committed by GitHub
parent cbf4c05b15
commit 2daf23ab0c
35 changed files with 561 additions and 271 deletions

View File

@ -3,6 +3,7 @@ import io
import os
import re
import subprocess
import sys
import warnings
from pathlib import Path
from typing import List, Set
@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
@ -324,8 +327,46 @@ if _is_cuda():
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()
# Fetch a pinned FlashAttention release with pip and place it in the
# third-party directory under the repository root.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
    [
        # Run pip through the interpreter that is executing this script.
        sys.executable,
        "-m",
        "pip",
        "install",
        "-q",
        f"--target={install_dir}",
        "einops",  # Dependency of flash-attn.
        f"flash-attn=={flash_attn_version}",
        "--no-dependencies",  # Required to avoid re-installing torch.
    ],
    # Inherit the current environment, but force CC to gcc for the pip run.
    env={**os.environ, "CC": "gcc"},
)
# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):
    """Extension build command that also bundles the third-party files.

    After the regular extension build finishes, the contents of
    ``install_dir`` are copied into ``THIRDPARTY_SUBDIR`` inside the build
    output directory (``self.build_lib``).
    """

    def run(self):
        """Run the normal build, then copy ``install_dir`` into the build tree."""
        super().run()
        target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
        # exist_ok=True replaces the former exists()-then-makedirs() pair,
        # which could fail if the directory appeared between the two calls.
        os.makedirs(target_dir, exist_ok=True)
        self.copy_tree(install_dir, target_dir)
class BinaryDistribution(setuptools.Distribution):
    """Distribution that always reports having extension modules."""

    def has_ext_modules(self):
        """Return True unconditionally.

        NOTE(review): presumably this forces the package to be treated as a
        binary (platform-specific) distribution -- confirm against setuptools.
        """
        return True
else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version()
vllm_extension_sources = [
"csrc/cache_kernels.cu",
@ -468,6 +509,7 @@ setuptools.setup(
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
distclass=BinaryDistribution,
package_data=package_data,
)