[Bugfix] consider related env vars for torch.compiled cache hash (#14953)

Signed-off-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
DefTruth
2025-03-23 23:53:09 +08:00
committed by GitHub
parent f90d34b498
commit 6ebaf9ac71
2 changed files with 46 additions and 0 deletions

View File

@ -357,6 +357,11 @@ class VllmBackend:
# graph.
factors = []
# 0. factors come from the env; for example, the value of
# VLLM_PP_LAYER_PARTITION affects the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
config_hash = vllm_config.compute_hash()

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional
@ -651,3 +652,43 @@ def set_vllm_use_v1(use_v1: bool):
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
def compute_hash() -> str:
    """Return an MD5 hex digest summarizing selected environment variables.

    WARNING: Whenever a new environment variable is added to this module,
    make sure it is listed in ``environment_variables_to_hash`` below if it
    affects the computation graph. For example, different values of
    VLLM_PP_LAYER_PARTITION generate different computation graphs, so it
    is included. Environment variables that affect the choice of kernels
    or attention backends should be included as well.

    Returns:
        The hexadecimal MD5 digest of ``str(factors)``, where ``factors``
        holds one entry per hashed variable, in list order.
    """
    factors: list[Any] = []

    def factorize(name: str) -> None:
        # Look the variable up once through the module-level __getattr__
        # instead of evaluating it twice.
        value = __getattr__(name)
        # NOTE(review): every falsy value (None, 0, False, "") is recorded
        # as the string "None", so those values are indistinguishable in
        # the hash. Kept as-is so previously computed hashes stay valid.
        factors.append(value if value else "None")

    # The values of these env vars may affect the computation graph.
    # TODO(DefTruth): hash all environment variables?
    environment_variables_to_hash = [
        "VLLM_PP_LAYER_PARTITION",
        "VLLM_MLA_DISABLE",
        "VLLM_USE_TRITON_FLASH_ATTN",
        "VLLM_USE_TRITON_AWQ",
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
    ]
    for key in environment_variables_to_hash:
        # Names that are not registered in environment_variables are
        # skipped and contribute nothing to the hash.
        if key in environment_variables:
            factorize(key)

    # The digest is only a cache key, not a security primitive. Saying so
    # explicitly keeps md5 usable on FIPS-restricted Python builds, and the
    # resulting digest is unchanged.
    digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
    return digest.hexdigest()