Merge branch 'main' into wentao-refactor-batch-invariant-fp8-deepgemm
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -20,9 +20,6 @@ from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
    ConfigFormat,
@@ -436,10 +433,6 @@ class ModelConfig:
        skip_mm_profiling: bool | None,
        video_pruning_rate: float | None,
    ) -> None:
        # Enable batch invariance settings if requested
        if vllm_is_batch_invariant():
            self.enforce_eager = True

        # Set the default seed to 0 in V1.
        # NOTE(woosuk): In V1, we use separate processes for workers (unless
        # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
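For context, the enforce-eager hook touched in this hunk is driven by vllm_is_batch_invariant(). A minimal, self-contained sketch of that pattern, assuming the flag is exposed through an environment variable such as VLLM_BATCH_INVARIANT (the env-var name is an assumption; the class below is a stand-in, not vLLM's ModelConfig):

import os


def is_batch_invariant() -> bool:
    # Assumption: batch invariance is toggled via an env flag.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"


class ConfigSketch:
    def __init__(self) -> None:
        self.enforce_eager = False
        # Batch-invariant execution forces eager mode, mirroring the hunk above.
        if is_batch_invariant():
            self.enforce_eager = True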
@@ -251,6 +251,9 @@ def disable_compile_cache() -> bool:


def use_aot_compile() -> bool:
    from vllm.model_executor.layers.batch_invariant import (
        vllm_is_batch_invariant,
    )
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    default_value = (
@@ -259,7 +262,10 @@ def use_aot_compile() -> bool:
        else "0"
    )

    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    return (
        not vllm_is_batch_invariant()
        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    )


def env_with_choices(
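A standalone sketch of the gating logic above: batch-invariant mode wins over the VLLM_USE_AOT_COMPILE opt-in. is_batch_invariant() stands in for vllm_is_batch_invariant(), the VLLM_BATCH_INVARIANT env-var name is an assumption, and the torch-version check on the default value is simplified away:

import os


def is_batch_invariant() -> bool:
    # Stand-in for vllm_is_batch_invariant(); env-var name is assumed.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"


def use_aot_compile(default_value: str = "0") -> bool:
    # AOT compile is used only when batch invariance is off AND the env flag opts in.
    return (
        not is_batch_invariant()
        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    )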
@@ -11,6 +11,7 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_torch_equal_or_newer

logger = init_logger(__name__)
@@ -716,6 +717,10 @@ def linear_batch_invariant(input, weight, bias=None):
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
_original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None


def is_batch_invariant_mode_enabled():
@@ -724,6 +729,8 @@ def is_batch_invariant_mode_enabled():

def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
    if _batch_invariant_MODE:
        return
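The new globals follow the same stash-and-restore pattern already used for torch.bmm: keep the original value in a module-level slot before patching so the disable path can put it back. A generic sketch of that pattern under illustrative names (the placeholder below is not the real batch-invariant kernel):

import torch

_original_torch_bmm = None


def _bmm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real implementation uses a deterministic kernel.
    return torch.matmul(a, b)


def enable() -> None:
    global _original_torch_bmm
    if _original_torch_bmm is None:
        _original_torch_bmm = torch.bmm   # stash the original once
        torch.bmm = _bmm_batch_invariant  # patch in the replacement


def disable() -> None:
    global _original_torch_bmm
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm   # restore the original
        _original_torch_bmm = None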
@@ -745,14 +752,75 @@ def enable_batch_invariant_mode():
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant

    _original_bf16_reduction_precision = (
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
    )
    _original_fp16_reduction_precision = (
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    )

    reduced_precision_val = (
        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
    )
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
        reduced_precision_val
    )
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
        reduced_precision_val
    )
    torch.backends.cuda.preferred_blas_library(backend="cublaslt")

    if not is_torch_equal_or_newer("2.10.0.dev"):
        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
        _original_cublaslt_workspace_size = os.environ.get(
            "CUBLASLT_WORKSPACE_SIZE", None
        )
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"


def disable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
    if not _batch_invariant_MODE:
        return

    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None

    if _original_bf16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
            _original_bf16_reduction_precision
        )
        _original_bf16_reduction_precision = None
    if _original_fp16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            _original_fp16_reduction_precision
        )
        _original_fp16_reduction_precision = None

    torch.backends.cuda.preferred_blas_library(backend="default")

    if not is_torch_equal_or_newer("2.10.0.dev"):
        # Set cublas env vars to previous results. If previous results are None,
        # that means the env vars were not set, so we should remove them.
        if _original_cublas_workspace_cfg:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
            del os.environ["CUBLAS_WORKSPACE_CONFIG"]

        if _original_cublaslt_workspace_size:
            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
            del os.environ["CUBLASLT_WORKSPACE_SIZE"]

        _original_cublas_workspace_cfg = None
        _original_cublaslt_workspace_size = None

    _batch_invariant_MODE = False
    _batch_invariant_LIB = None
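The CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE handling above is a save-set-restore dance on os.environ. A small generic sketch of the same idea as a context manager (not vLLM API, just an illustration of the pattern):

import os
from contextlib import contextmanager


@contextmanager
def temporary_env(name: str, value: str):
    """Set an env var for the duration of the block, then restore or unset it."""
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is not None:
            os.environ[name] = original  # restore the previous value
        else:
            os.environ.pop(name, None)   # it was unset before, so unset it again


# Example: pin the cuBLAS workspace split for a block of matmuls, using the
# same value the hunk above applies on pre-2.10 torch.
with temporary_env("CUBLAS_WORKSPACE_CONFIG", ":16:8"):
    pass  # run matmuls here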
@@ -831,6 +899,9 @@ def override_envs_for_invariance():
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"

    # torch.compile settings
    os.environ["VLLM_USE_AOT_COMPILE"] = "0"


def init_batch_invariance():
    # this will hit all the csrc overrides as well
@@ -360,6 +360,7 @@ class Fp8LinearMethod(LinearMethodBase):
        self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
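self.use_deep_gemm is a capability flag computed once at construction; the apply path can then branch on it together with block_quant. A hedged sketch of that dispatch shape (class and helper names here are illustrative, not the actual Fp8LinearMethod code):

import torch


class Fp8LinearSketch:
    def __init__(self, weight_block_size: list[int] | None, deep_gemm_supported: bool) -> None:
        self.weight_block_size = weight_block_size
        self.block_quant = weight_block_size is not None
        self.use_deep_gemm = deep_gemm_supported

    def apply(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        if self.block_quant and self.use_deep_gemm:
            return self._deep_gemm_path(x, weight)  # blockwise FP8 via DeepGEMM
        return self._default_path(x, weight)        # fallback kernel

    def _deep_gemm_path(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Placeholder: real code would invoke the DeepGEMM FP8 kernel here.
        return x @ weight.t()

    def _default_path(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return x @ weight.t()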
@@ -306,11 +306,12 @@ class KVCacheManager:
                "Computed blocks should be empty when prefix caching is disabled"
            )

        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        self.coordinator.save_new_computed_blocks(
            request.request_id, new_computed_block_list
        )
        if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
            # Append the new computed blocks to the request blocks until now to
            # avoid the case where the new blocks cannot be allocated.
            self.coordinator.save_new_computed_blocks(
                request.request_id, new_computed_block_list
            )

        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot, num_encoder_tokens
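The new guard compares against the shared empty-blocks object by identity, so the coordinator call is skipped in the common "no new computed blocks" case rather than being invoked with the sentinel. A minimal illustration of that identity-vs-equality distinction (names are illustrative):

EMPTY_BLOCKS: list[int] = []  # shared sentinel, analogous to self.empty_kv_cache_blocks.blocks


def save_if_needed(new_computed_block_list: list[int], saved: list[list[int]]) -> None:
    # `is not` checks identity: a caller-built empty list is still saved,
    # but the shared sentinel object is recognized and skipped entirely.
    if new_computed_block_list is not EMPTY_BLOCKS:
        saved.append(new_computed_block_list)


saved: list[list[int]] = []
save_if_needed(EMPTY_BLOCKS, saved)  # skipped: same object as the sentinel
save_if_needed([], saved)            # saved: equal to the sentinel but not identical
assert len(saved) == 1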
@@ -151,7 +151,7 @@ class SingleTypeKVCacheManager(ABC):
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
        """
        num_cached_blocks = self.num_cached_block[request.request_id]
        num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
        num_full_blocks = num_tokens // self.block_size

        if num_cached_blocks >= num_full_blocks:
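Switching from bracket indexing to .get(..., 0) means a request with no entry in num_cached_block reads as zero cached blocks instead of raising KeyError. A tiny illustration of the difference:

num_cached_block: dict[str, int] = {"req-1": 3}

# Bracket indexing raises KeyError for an unseen request id.
try:
    _ = num_cached_block["req-2"]
except KeyError:
    pass

# .get() with a default treats an unseen request as having zero cached blocks.
assert num_cached_block.get("req-2", 0) == 0
assert num_cached_block.get("req-1", 0) == 3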