[Hardware][ROCM] using current_platform.is_rocm (#9642)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-28 12:07:00 +08:00
committed by GitHub
parent 34a9941620
commit 4e2d95e372
32 changed files with 165 additions and 151 deletions

View File

@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from vllm import LLM
from vllm.utils import is_hip
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal
@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool,
) -> None:
if backend == "FLASHINFER" and is_hip():
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend

View File

@ -5,7 +5,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip
from vllm.platforms import current_platform
TEST_MODELS = [
("facebook/opt-125m", {}),
@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization": "marlin"
}))
if not is_hip() and is_quant_method_supported("awq"):
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))

View File

@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

View File

@ -6,11 +6,12 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
if not is_hip():
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)

View File

@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"

View File

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.

View File

@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.utils import is_hip
from vllm.platforms import current_platform
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
@ -82,7 +82,7 @@ class TestResources(NamedTuple):
will leverage attn_backend for the purpose of
constructing backend-compatible attention
metadata instances
Attributes:
* scale: 1/sqrt(d) scale factor for attn
@ -105,10 +105,10 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
Build key components for performing encoder/decoder attention test.
Note that
(1) The Attention instance constructed here, automatically selects
(1) The Attention instance constructed here, automatically selects
an attention backend class based on platform info & a set of canned
heuristics, so
(2) The attention backend instance constructed here is thus *not
(2) The attention backend instance constructed here is thus *not
the same backend instance* used by attn, but rather it is
intended to be a *different instance* of the *same backend class*;
therefore,
@ -156,7 +156,7 @@ def _encoder_attn_setup(
'''
Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed.
A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries.
@ -169,14 +169,14 @@ def _encoder_attn_setup(
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
Returns:
* PhaseTestParameters data structure comprising (1) packed query/key/value
tensors, (2) the ideal output of attention computed using a naive
implementation, and (3) KVCache field set to None
@ -265,7 +265,7 @@ def _decoder_attn_setup(
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
@ -275,14 +275,14 @@ def _decoder_attn_setup(
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
head_size) query/key/value tensors
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for prefill phase.
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for decode phase.
* max_block_idx: max physical address in decoder self-attention block-table
(intended to be used as the base address for the encoder/
@ -436,12 +436,12 @@ def _enc_dec_cross_attn_setup_reuses_query(
This function also constructs the cross-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr.
block_base_addr.
Arguments:
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
num_heads x head_size) decoder self-attention inputs;
num_heads x head_size) decoder self-attention inputs;
this function relies on the query and q_seq_lens
fields
* encoder_test_params: PhaseTestParameters data structure which was
@ -452,7 +452,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
self-attention; all fields
including KV cache required
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
@ -460,16 +460,16 @@ def _enc_dec_cross_attn_setup_reuses_query(
Returns:
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for decode phase.
'''
@ -596,7 +596,7 @@ def _run_encoder_attention_test(
'''
Run encoder attention.
attn.forward() is passed attn_type=AttentionType.ENCODER in order
attn.forward() is passed attn_type=AttentionType.ENCODER in order
to configure the kernel invocation for encoder attention
Requires attn_metadata.num_decode_tokens == 0
@ -607,7 +607,7 @@ def _run_encoder_attention_test(
* attn: Attention wrapper instance
* encoder_test_params: encoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
@ -646,7 +646,7 @@ def _run_decoder_self_attention_test(
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
@ -694,11 +694,11 @@ def _run_encoder_decoder_cross_attention_test(
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query field
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -837,14 +840,14 @@ def test_e2e_enc_dec_attn(
attributes for prefill-phase, and (2) an analogous attention metadata
structure but for decode-phase
* Test attention steps in the following order
* Encoder attention
* Prefill self-attention
* Prefill cross-attention
* Decode self-attention
* Decode cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
block tables, which one hopes to avoid
@ -864,10 +867,11 @@ def test_e2e_enc_dec_attn(
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
of EncoderDecoderModelRunner, which constructs a single attention metadata

View File

@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import is_hip, seed_everything
from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,

View File

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b"
@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,

View File

@ -8,7 +8,7 @@ import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
@dataclass
@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS: List[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm.
if is_hip():
if current_platform.is_rocm():
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",

View File

@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
@ -70,7 +71,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@ -151,7 +152,7 @@ def run_test(
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
current_platform.is_rocm(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"

View File

@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

View File

@ -5,7 +5,7 @@ tensor parallelism.
import pytest
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp
@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs,

View File

@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip)
cuda_device_count_stateless, get_open_port)
if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
if is_hip():
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10