[torch.compile] limit inductor threads and lazy import quant (#10482)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-20 18:36:33 -08:00
committed by GitHub
parent 2f77b6cfec
commit 388ee3de66
11 changed files with 178 additions and 64 deletions

View File

@ -1,4 +1,4 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform
@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
capability = current_platform.get_device_capability()
assert capability is not None
min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
min_capability = get_quantization_config(quant_method).get_min_capability()
return capability.to_int() >= min_capability

View File

@ -0,0 +1,68 @@
# Description: Test that `import vllm` does not eagerly import modules that
# are meant to be imported lazily.
# The `blame` utility below cannot be placed in `vllm.utils`;
# this file needs to be a standalone script.
import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator
@dataclasses.dataclass
class BlameResult:
    """Outcome of a `blame` tracing session."""
    # Set to True the first time the traced condition evaluates truthy.
    found: bool = False
    # Formatted call stack captured at the moment the condition first held.
    trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
    """
    Trace the function calls to find the first function that satisfies the
    condition. The trace stack will be stored in the result.

    `func` is a zero-argument callable; it is evaluated on every function
    call and return event while the context is active, until it first
    returns a truthy value.

    Usage:

    ```python
    with blame(lambda: some_condition()) as result:
        # do something

    if result.found:
        print(result.trace_stack)
    ```

    The trace hook is removed when the context exits, including when the
    body of the `with` statement raises.
    """
    result = BlameResult()

    def _trace_calls(frame, event, arg=None):
        if event in ['call', 'return']:
            # for every function call or return
            try:
                # Temporarily disable the trace function so that evaluating
                # the condition is not itself traced.
                sys.settrace(None)
                # check condition here
                if not result.found and func():
                    result.found = True
                    result.trace_stack = "".join(traceback.format_stack())
                # Re-enable the trace function
                sys.settrace(_trace_calls)
            except NameError:
                # modules are deleted during shutdown
                pass
        return _trace_calls

    sys.settrace(_trace_calls)
    try:
        yield result
    finally:
        # Always uninstall the hook; otherwise an exception raised inside
        # the `with` body would leave tracing enabled for the rest of the
        # process.
        sys.settrace(None)
# Fully-qualified name of the module that must NOT be pulled in merely by
# importing vllm.
module_name = "torch._inductor.async_compile"


def _is_imported() -> bool:
    """Report whether `module_name` is currently present in sys.modules."""
    return module_name in sys.modules


# Run the vllm import under tracing so that, if the module does get
# imported, the stack at the first point it was observed is recorded.
with blame(_is_imported) as result:
    import vllm  # noqa

failure_message = (f"Module {module_name} is already imported, the"
                   f" first import location is:\n{result.trace_stack}")
assert not result.found, failure_message

print(f"Module {module_name} is not imported yet")