Compare commits
2 Commits
use-uv-pyt
...
woosuk/fix
| Author | SHA1 | Date | |
|---|---|---|---|
| 936da0f740 | |||
| 20098c10d9 |
@ -82,7 +82,7 @@ class CUDAGraphWrapper:
|
|||||||
# TODO: in the future, if we want to use multiple
|
# TODO: in the future, if we want to use multiple
|
||||||
# streams, it might not be safe to share a global pool.
|
# streams, it might not be safe to share a global pool.
|
||||||
# only investigate this when we use multiple streams
|
# only investigate this when we use multiple streams
|
||||||
self.graph_pool = current_platform.get_global_graph_pool()
|
self.graph_pool = current_platform.graph_pool_handle()
|
||||||
|
|
||||||
if cudagraph_options is None:
|
if cudagraph_options is None:
|
||||||
cudagraph_options = CUDAGraphOptions()
|
cudagraph_options = CUDAGraphOptions()
|
||||||
|
|||||||
@ -140,8 +140,6 @@ class Platform:
|
|||||||
|
|
||||||
additional_env_vars: list[str] = []
|
additional_env_vars: list[str] = []
|
||||||
|
|
||||||
_global_graph_pool: Optional[Any] = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_dtypes(self) -> list[torch.dtype]:
|
def supported_dtypes(self) -> list[torch.dtype]:
|
||||||
"""Returns the supported dtypes for the current platform."""
|
"""Returns the supported dtypes for the current platform."""
|
||||||
@ -535,15 +533,6 @@ class Platform:
|
|||||||
" attribute.", self.device_type, key)
|
" attribute.", self.device_type, key)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_global_graph_pool(self) -> Any:
|
|
||||||
"""
|
|
||||||
Return the global graph pool for this platform.
|
|
||||||
"""
|
|
||||||
cls = self.__class__
|
|
||||||
if cls._global_graph_pool is None:
|
|
||||||
cls._global_graph_pool = self.graph_pool_handle()
|
|
||||||
return cls._global_graph_pool
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class UBatchWrapper:
|
|||||||
if runtime_mode is not CUDAGraphMode.NONE:
|
if runtime_mode is not CUDAGraphMode.NONE:
|
||||||
self.cudagraph_wrapper = CUDAGraphWrapper(
|
self.cudagraph_wrapper = CUDAGraphWrapper(
|
||||||
runnable, vllm_config, runtime_mode=runtime_mode)
|
runnable, vllm_config, runtime_mode=runtime_mode)
|
||||||
self.graph_pool = current_platform.get_global_graph_pool()
|
self.graph_pool = current_platform.graph_pool_handle()
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
# allow accessing the attributes of the runnable.
|
# allow accessing the attributes of the runnable.
|
||||||
|
|||||||
Reference in New Issue
Block a user