Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
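The rewrite is mechanical: every `Optional[x]` becomes `x | None`, every `Union[x, y]` becomes `x | y` (PEP 604), and in a few files `typing.Callable` moves to `collections.abc.Callable`. The snippet below is a minimal illustrative sketch of the pattern, not code from the diff (the `apply` function and its parameters are hypothetical); it assumes Python 3.10+, or `from __future__ import annotations` on older interpreters:

```python
# Illustrative only -- this function is not part of the changed files.
from collections.abc import Callable

import torch


# Before: from typing import Callable, Optional, Union
#   def apply(fn: Callable, x: torch.Tensor, scale: Union[int, float],
#             out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: ...
# After, with PEP 604 unions (the typing imports can then be dropped):
def apply(
    fn: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    scale: int | float,
    out: torch.Tensor | None = None,
) -> torch.Tensor | None:
    result = fn(x) * scale
    if out is not None:
        out.copy_(result)  # write in place, mirroring the torch `out=` convention
        return None
    return result
```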
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -27,8 +26,8 @@ def ref_paged_attn(
     kv_lens: list[int],
     block_tables: torch.Tensor,
     scale: float,
-    sliding_window: Optional[int] = None,
-    soft_cap: Optional[float] = None,
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -94,12 +93,12 @@ def test_varlen_with_paged_kv(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
     head_size: int,
-    sliding_window: Optional[int],
+    sliding_window: int | None,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     num_blocks: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import random
-from typing import Optional

 import pytest
 import torch
@@ -50,7 +49,7 @@ def ref_masked_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
-    attn_mask: Optional[torch.Tensor] = None,
+    attn_mask: torch.Tensor | None = None,
 ) -> torch.Tensor:
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     if attn_mask is not None:
@@ -69,7 +68,7 @@ def ref_single_query_cached_kv_attention(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     scale: float,
-    alibi_slopes: Optional[torch.Tensor],
+    alibi_slopes: torch.Tensor | None,
 ) -> None:
     num_query_heads = query.shape[1]
     num_kv_heads = value_cache.shape[1]
@@ -415,7 +414,7 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
-    alibi_bias: Optional[list[torch.Tensor]],
+    alibi_bias: list[torch.Tensor] | None,
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -85,7 +84,7 @@ def test_cascade(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     num_blocks: int,
     fa_version: int,
 ) -> None:

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 import random
-from typing import Optional

 import pytest
 import torch
@@ -17,7 +16,7 @@ def cal_diff(
     y: torch.Tensor,
     name: str,
     use_fp8: bool = False,
-    diff_threshold: Optional[float] = None,
+    diff_threshold: float | None = None,
 ) -> None:
     x, y = x.double(), y.double()
     cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -34,8 +33,8 @@ def ref_paged_attn(
     kv_lens: list[int],
     block_tables: torch.Tensor,
     scale: float,
-    sliding_window: Optional[int] = None,
-    soft_cap: Optional[float] = None,
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -103,11 +102,11 @@ def test_flash_attn_with_paged_kv(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     num_blocks: int,
-    sliding_window: Optional[int],
+    sliding_window: int | None,
     fa_version: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
@@ -221,13 +220,13 @@ def test_varlen_with_paged_kv(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
     head_size: int,
-    sliding_window: Optional[int],
+    sliding_window: int | None,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     num_blocks: int,
     fa_version: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import flashinfer
 import pytest
@@ -26,8 +25,8 @@ def ref_paged_attn(
     kv_lens: list[int],
     block_tables: torch.Tensor,
     scale: float,
-    sliding_window: Optional[int] = None,
-    soft_cap: Optional[float] = None,
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -90,8 +89,8 @@ def test_flashinfer_decode_with_paged_kv(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
-    sliding_window: Optional[int],
+    soft_cap: float | None,
+    sliding_window: int | None,
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -185,8 +184,8 @@ def test_flashinfer_prefill_with_paged_kv(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
-    sliding_window: Optional[int],
+    soft_cap: float | None,
+    sliding_window: int | None,
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -288,7 +287,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
 ) -> None:
     pytest.skip("TODO: fix the accuracy issue")
     torch.set_default_device("cuda")
@@ -398,7 +397,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
 ) -> None:
     # test doesn't work for num_heads = (16,16)
     torch.set_default_device("cuda")

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional

 import flashinfer
 import pytest
@@ -68,9 +67,7 @@ NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
-    quant_dtypes: tuple[
-        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
-    ],
+    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
     batch_size: int,
     max_seq_lens: tuple[int, int],
     num_heads: tuple[int, int],
@@ -78,7 +75,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_layout: str,
     block_size: int,
     window_left: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
@@ -267,9 +264,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
-    quant_dtypes: tuple[
-        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
-    ],
+    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
     batch_size: int,
     max_seq_lens: tuple[int, int],
     num_heads: tuple[int, int],
@@ -277,7 +272,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     kv_layout: str,
     block_size: int,
     window_left: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional

 import pytest
 import torch
@@ -20,7 +19,7 @@ def merge_attn_states_torch(
     prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
     suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
-    output_lse: Optional[torch.Tensor] = None,  # [NUM_HEADS, NUM_TOKENS]
+    output_lse: torch.Tensor | None = None,  # [NUM_HEADS, NUM_TOKENS]
 ):
     p_lse = prefix_lse
     s_lse = suffix_lse

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -32,8 +31,8 @@ def ref_paged_attn(
     kv_lens: list[int],
     block_tables: torch.Tensor,
     scale: float,
-    sliding_window: Optional[int] = None,
-    soft_cap: Optional[float] = None,
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -98,12 +97,12 @@ def test_triton_unified_attn(
     seq_lens: list[tuple[int, int]],
     num_heads: tuple[int, int],
     head_size: int,
-    sliding_window: Optional[int],
+    sliding_window: int | None,
     dtype: torch.dtype,
     block_size: int,
-    soft_cap: Optional[float],
+    soft_cap: float | None,
     num_blocks: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
 ) -> None:
     torch.set_default_device("cuda")

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional, Union

 import pytest
 import torch
@@ -31,13 +30,13 @@ EPS = 1e-6
 ## Helpers


-def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
+def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
     return torch.as_tensor(x, dtype=torch.float32, device="cuda")


 def ref_rms_norm(
-    rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor]
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor | None
+) -> tuple[torch.Tensor, torch.Tensor | None]:
     if residual is not None:
         residual = residual.clone()
     out, residual = rms_norm_layer.forward_native(x, residual)
@@ -51,9 +50,9 @@ def ref_dynamic_per_token_quant(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
-    residual: Optional[torch.Tensor],
-    scale_ub: Optional[torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    residual: torch.Tensor | None,
+    scale_ub: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     if scale_ub is not None:
         assert quant_dtype == torch.float8_e4m3fn

@@ -76,9 +75,9 @@ def ref_impl(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
-    residual: Optional[torch.Tensor],
-    scale_ub: Optional[torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    residual: torch.Tensor | None,
+    scale_ub: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     return ref_dynamic_per_token_quant(
         rms_norm_layer, x, quant_dtype, residual, scale_ub
     )
@@ -88,9 +87,9 @@ def ops_dynamic_per_token_quant(
     weight: torch.Tensor,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
-    residual: Optional[torch.Tensor],
-    scale_ub: Optional[torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    residual: torch.Tensor | None,
+    scale_ub: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     if residual is not None:
         residual = residual.clone()
     out, scales = ops.rms_norm_dynamic_per_token_quant(
@@ -103,9 +102,9 @@ def ops_impl(
     weight: torch.Tensor,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
-    residual: Optional[torch.Tensor],
-    scale_ub: Optional[torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    residual: torch.Tensor | None,
+    scale_ub: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)


@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from collections.abc import Callable
 from itertools import product
-from typing import Callable, Optional

 import pytest
 import torch
@@ -68,7 +68,7 @@ def test_rotary_embedding(
     seq_len: int,
     num_heads: int,
     head_size: int,
-    rotary_dim: Optional[int],
+    rotary_dim: int | None,
     dtype: torch.dtype,
     seed: int,
     device: str,

@@ -4,8 +4,6 @@
 Tests for miscellaneous utilities
 """

-from typing import Optional
-
 import pytest
 import torch
@@ -17,7 +15,7 @@ def rotary_embedding_opcheck(
     rot,
     positions: torch.Tensor,
     query: torch.Tensor,
-    key: Optional[torch.Tensor] = None,
+    key: torch.Tensor | None = None,
 ):
     cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -19,11 +18,11 @@ from vllm.platforms import current_platform
 def causal_conv1d_ref(
     x: torch.Tensor,
     weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    initial_states: Optional[torch.Tensor] = None,
+    bias: torch.Tensor | None = None,
+    initial_states: torch.Tensor | None = None,
     return_final_states: bool = False,
-    final_states_out: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
+    final_states_out: torch.Tensor | None = None,
+    activation: str | None = "silu",
 ):
     """
     x: (batch, dim, seqlen)
@@ -117,12 +116,12 @@ def causal_conv1d_update_ref(
 def causal_conv1d_opcheck_fn(
     x: torch.Tensor,
     weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    cu_seq_len: Optional[torch.Tensor] = None,
-    cache_indices: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    conv_states: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
+    bias: torch.Tensor | None = None,
+    cu_seq_len: torch.Tensor | None = None,
+    cache_indices: torch.Tensor | None = None,
+    has_initial_state: torch.Tensor | None = None,
+    conv_states: torch.Tensor | None = None,
+    activation: str | None = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
 ):
     """

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any

 import torch

@@ -35,7 +35,7 @@ from .mk_objects import (
 from .parallel_utils import ProcessGroupInfo


-def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
+def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
     if t is None:
         return f"{name} : None"
     else:
@@ -44,21 +44,21 @@ def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:

 @dataclass
 class Config:
-    Ms: Union[list[int], int]
+    Ms: list[int] | int
     K: int
     N: int
     E: int
-    topks: Union[list[int], int]
+    topks: list[int] | int
     dtype: torch.dtype
-    quant_config: Optional[TestMoEQuantConfig]
+    quant_config: TestMoEQuantConfig | None

     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
     fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute

-    fused_moe_chunk_size: Optional[int]
+    fused_moe_chunk_size: int | None
     world_size: int

-    torch_trace_dir_path: Optional[str] = None
+    torch_trace_dir_path: str | None = None

     def __post_init__(self):
         if self.quant_config is None:
@@ -93,7 +93,7 @@ class Config:
         return self.Ms

     @property
-    def quant_dtype(self) -> Union[torch.dtype, str, None]:
+    def quant_dtype(self) -> torch.dtype | str | None:
         assert self.quant_config is not None
         return self.quant_config.quant_dtype

@@ -112,7 +112,7 @@ class Config:
         return self.quant_config.per_out_ch_quant

     @property
-    def quant_block_shape(self) -> Optional[list[int]]:
+    def quant_block_shape(self) -> list[int] | None:
         assert self.quant_config is not None
         return self.quant_config.block_shape

@@ -209,7 +209,7 @@ class Config:
         info = prepare_finalize_info(self.prepare_finalize_type)
         return info.backend

-    def is_valid(self) -> tuple[bool, Optional[str]]:
+    def is_valid(self) -> tuple[bool, str | None]:
         # Check prepare-finalize and fused-experts compatibility
         if self.is_batched_prepare_finalize():
             if not self.is_batched_fused_experts():
@@ -280,10 +280,10 @@ class Config:
 class WeightTensors:
     w1: torch.Tensor
     w2: torch.Tensor
-    w1_scale: Optional[torch.Tensor]
-    w2_scale: Optional[torch.Tensor]
-    w1_gs: Optional[torch.Tensor] = None
-    w2_gs: Optional[torch.Tensor] = None
+    w1_scale: torch.Tensor | None
+    w2_scale: torch.Tensor | None
+    w1_gs: torch.Tensor | None = None
+    w2_gs: torch.Tensor | None = None

     def describe(self):
         s = ""
@@ -351,11 +351,11 @@ class WeightTensors:
 @dataclass
 class RankTensors:
     hidden_states: torch.Tensor
-    hidden_states_scale: Optional[torch.Tensor]
+    hidden_states_scale: torch.Tensor | None

     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
-    expert_map: Optional[torch.Tensor]
+    expert_map: torch.Tensor | None

     def describe(self):
         s = ""
@@ -370,7 +370,7 @@ class RankTensors:
     @staticmethod
     def make_hidden_states(
         config: Config,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """
         Return hidden_states
         """

@@ -4,7 +4,6 @@
 import copy
 from enum import Enum
 from itertools import product
-from typing import Optional

 import torch
 from tqdm import tqdm
@@ -82,7 +81,7 @@ def make_feature_matrix(csv_file_path: str):
     import pandas as pd

     def add_to_results(
-        config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
+        config: Config, success: Result, results_df: pd.DataFrame | None = None
     ):
         config_dict = asdict(config)
         config_dict["prepare_finalize_type"] = config_dict[
@@ -121,7 +120,7 @@ def make_feature_matrix(csv_file_path: str):
         product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
     )

-    results_df: Optional[pd.DataFrame] = None
+    results_df: pd.DataFrame | None = None
     for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
         combinations
     ):

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import Optional, Union

 import torch

@@ -43,25 +42,25 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

 @dataclass
 class TestMoEQuantConfig:
-    quant_dtype: Union[torch.dtype, str, None]
+    quant_dtype: torch.dtype | str | None
     per_out_ch_quant: bool
     per_act_token_quant: bool
-    block_shape: Optional[list[int]]
+    block_shape: list[int] | None


 @dataclass
 class PrepareFinalizeInfo:
     activation_format: mk.FusedMoEActivationFormat
-    supported_dtypes: list[Union[torch.dtype, str]]
+    supported_dtypes: list[torch.dtype | str]
     blocked_quantization_support: bool
-    backend: Optional[str]
+    backend: str | None
     supports_apply_weight_on_input: bool = True


 @dataclass
 class ExpertInfo:
     activation_format: mk.FusedMoEActivationFormat
-    supported_dtypes: list[Union[torch.dtype, str]]
+    supported_dtypes: list[torch.dtype | str]
     blocked_quantization_support: bool
     supports_chunking: bool
     supports_expert_map: bool
@@ -78,7 +77,7 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []

 standard_format = mk.FusedMoEActivationFormat.Standard
 batched_format = mk.FusedMoEActivationFormat.BatchedExperts
-common_float_types: list[Union[torch.dtype, str]] = [
+common_float_types: list[torch.dtype | str] = [
     torch.float8_e4m3fn,
     torch.bfloat16,
     torch.float16,
@@ -92,9 +91,9 @@ fp8_types = [torch.float8_e4m3fn]
 def register_prepare_and_finalize(
     kind,
     activation_format: mk.FusedMoEActivationFormat,
-    supported_dtypes: list[Union[torch.dtype, str]],
+    supported_dtypes: list[torch.dtype | str],
     blocked_quantization_support: bool,
-    backend: Optional[str],
+    backend: str | None,
     force_multigpu: bool = False,
     supports_apply_weight_on_input: bool = True,
 ):
@@ -121,7 +120,7 @@ def register_prepare_and_finalize(
 def register_experts(
     kind,
     activation_format: mk.FusedMoEActivationFormat,
-    supported_dtypes: list[Union[torch.dtype, str]],
+    supported_dtypes: list[torch.dtype | str],
     blocked_quantization_support: bool,
     supports_chunking: bool,
     supports_expert_map: bool,
@@ -340,7 +339,7 @@ if cutlass_fp4_supported():
         supports_expert_map=False,
     )

-MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
+MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
     None,
     # per-channel / per-column weights and per-tensor activations
     TestMoEQuantConfig(
@@ -395,7 +394,7 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():

 def make_prepare_finalize(
     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
-    backend: Optional[str],
+    backend: str | None,
     moe: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig,
 ) -> mk.FusedMoEPrepareAndFinalize:

@@ -3,11 +3,12 @@
 import dataclasses
 import os
 import traceback
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Concatenate

 import torch
 from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import ParamSpec

 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed import init_distributed_environment, initialize_model_parallel
@@ -58,9 +59,9 @@ def _worker_parallel_launch(
     world_local_size: int,
     node_rank: int,
     init_method: str,
-    worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
-    vllm_config: Optional[VllmConfig],
-    env_dict: Optional[dict],
+    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
+    vllm_config: VllmConfig | None,
+    env_dict: dict | None,
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> None:

@@ -2,8 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy
+from collections.abc import Callable
 from itertools import product
-from typing import Any, Callable
+from typing import Any

 import torch

@@ -7,12 +7,13 @@ DeepEP test utilities
 import dataclasses
 import os
 import traceback
-from typing import Callable, Optional
+from collections.abc import Callable
+from typing import Concatenate

 import torch
 from torch.distributed import ProcessGroup
 from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import ParamSpec

 from vllm.utils import get_open_port, has_deep_ep

@@ -126,8 +127,8 @@ def make_deepep_ht_a2a(
     pgi: ProcessGroupInfo,
     dp_size: int,
     ht_args: DeepEPHTArgs,
-    q_dtype: Optional[torch.dtype] = None,
-    block_shape: Optional[list[int]] = None,
+    q_dtype: torch.dtype | None = None,
+    block_shape: list[int] | None = None,
 ):
     import deep_ep

@@ -153,8 +154,8 @@ def make_deepep_ll_a2a(
     pg: ProcessGroup,
     pgi: ProcessGroupInfo,
     deepep_ll_args: DeepEPLLArgs,
-    q_dtype: Optional[torch.dtype] = None,
-    block_shape: Optional[list[int]] = None,
+    q_dtype: torch.dtype | None = None,
+    block_shape: list[int] | None = None,
 ):
     import deep_ep

@@ -185,10 +186,10 @@ def make_deepep_a2a(
     pg: ProcessGroup,
     pgi: ProcessGroupInfo,
     dp_size: int,
-    deepep_ht_args: Optional[DeepEPHTArgs],
-    deepep_ll_args: Optional[DeepEPLLArgs],
-    q_dtype: Optional[torch.dtype] = None,
-    block_shape: Optional[list[int]] = None,
+    deepep_ht_args: DeepEPHTArgs | None,
+    deepep_ll_args: DeepEPLLArgs | None,
+    q_dtype: torch.dtype | None = None,
+    block_shape: list[int] | None = None,
 ):
     if deepep_ht_args is not None:
         assert deepep_ll_args is None

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import Optional

 import pytest
 import torch
@@ -55,7 +54,7 @@ vllm_config.scheduler_config.max_model_len = 8192
 @dataclass
 class BatchedMMConfig:
     in_dtype: torch.dtype
-    quant_dtype: Optional[torch.dtype]
+    quant_dtype: torch.dtype | None
     out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
@@ -115,7 +114,7 @@ def test_batched_mm(
     K: int,
     N: int,
     dtype: torch.dtype,
-    block_shape: Optional[list[int]],
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
 ):
     current_platform.seed_everything(7)
@@ -242,7 +241,7 @@ def test_fused_moe_batched_experts(
     topk: int,
     dtype: torch.dtype,
     per_act_token_quant: bool,
-    block_shape: Optional[list[int]],
+    block_shape: list[int] | None,
     input_scales: bool,
 ):
     current_platform.seed_everything(7)

@@ -5,7 +5,6 @@ Tests compute_expert_num_tokens kernels
 """

 import dataclasses
-from typing import Optional

 import pytest
 import torch
@@ -16,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
 @dataclasses.dataclass
 class TestTensors:
     topk_ids: torch.Tensor
-    expert_map: Optional[torch.Tensor] = None
+    expert_map: torch.Tensor | None = None

     def to_device(self, device: str):
         self.topk_ids = self.topk_ids.to(device=device)

@@ -3,7 +3,6 @@
 import copy
 import dataclasses
 from math import prod
-from typing import Optional

 import pytest
 import torch
@@ -85,16 +84,16 @@ class MOETensors:
 @dataclasses.dataclass
 class MOETensors8Bit(MOETensors):
     # quantized
-    a_q: Optional[torch.Tensor] = None  # a -> a_q
-    w1_q: Optional[torch.Tensor] = None  # w1 -> w1_q
-    w2_q: Optional[torch.Tensor] = None  # w2 -> w2_q
-    a_scale: Optional[torch.Tensor] = None
-    w1_scale: Optional[torch.Tensor] = None
-    w2_scale: Optional[torch.Tensor] = None
+    a_q: torch.Tensor | None = None  # a -> a_q
+    w1_q: torch.Tensor | None = None  # w1 -> w1_q
+    w2_q: torch.Tensor | None = None  # w2 -> w2_q
+    a_scale: torch.Tensor | None = None
+    w1_scale: torch.Tensor | None = None
+    w2_scale: torch.Tensor | None = None
     # dequantized
-    a_d: Optional[torch.Tensor] = None  # a -> a_q -> a_d
-    w1_d: Optional[torch.Tensor] = None  # w1 -> w1_q -> w1_d
-    w2_d: Optional[torch.Tensor] = None  # w2 -> w2_q -> w2_d
+    a_d: torch.Tensor | None = None  # a -> a_q -> a_d
+    w1_d: torch.Tensor | None = None  # w1 -> w1_q -> w1_d
+    w2_d: torch.Tensor | None = None  # w2 -> w2_q -> w2_d

     @staticmethod
     def make_moe_tensors_8bit(
@@ -209,7 +208,7 @@ def run_8_bit(
     topk_ids: torch.Tensor,
     per_act_token: bool,
     per_out_ch: bool,
-    num_local_experts: Optional[int] = None,
+    num_local_experts: int | None = None,
 ) -> torch.Tensor:
     assert not any(
         [
@@ -280,7 +279,7 @@ def test_cutlass_moe_8_bit_no_graph(
     per_act_token: bool,
     per_out_ch: bool,
     monkeypatch,
-    ep_size: Optional[int] = None,
+    ep_size: int | None = None,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

@@ -7,7 +7,6 @@ fp8 block-quantized case.
 """

 import dataclasses
-from typing import Optional

 import pytest
 import torch.distributed
@@ -92,13 +91,13 @@
     block_size: list[int]
     # configs for testing low-latency kernels
    low_latency: bool
-    use_fp8_dispatch: Optional[bool] = False
+    use_fp8_dispatch: bool | None = False


 @dataclasses.dataclass
 class TestTensors:
     rank_tokens: torch.Tensor  # all ranks make this many tokens
-    rank_token_scales: Optional[torch.Tensor]
+    rank_token_scales: torch.Tensor | None
     topk: torch.Tensor
     topk_weights: torch.Tensor
     config: TestConfig
@@ -143,7 +142,7 @@ def make_ll_modular_kernel(
     max_tokens_per_rank: int,
     dp_size: int,
     hidden_size: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
     test_config: TestConfig,
     quant_config: FusedMoEQuantConfig,
 ) -> FusedMoEModularKernel:
@@ -179,7 +178,7 @@ def make_ht_modular_kernel(
     pgi: ProcessGroupInfo,
     dp_size: int,
     num_local_experts: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
     test_config: TestConfig,
     quant_config: FusedMoEQuantConfig,
 ) -> FusedMoEModularKernel:
@@ -249,8 +248,8 @@ def deepep_deepgemm_moe_impl(
     test_tensors: TestTensors,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
+    w1_scale: torch.Tensor | None,
+    w2_scale: torch.Tensor | None,
 ) -> torch.Tensor:
     test_config = test_tensors.config
     num_experts = test_config.num_experts

@@ -5,7 +5,6 @@ Test deepep dispatch-combine logic
 """

 import dataclasses
-from typing import Optional, Union

 import pytest
 import torch.distributed
@@ -90,7 +89,7 @@ class TestConfig:
 @dataclasses.dataclass
 class TestTensors:
     rank_tokens: torch.Tensor  # all ranks make this many tokens
-    rank_token_scales: Optional[torch.Tensor]
+    rank_token_scales: torch.Tensor | None
     topk: torch.Tensor
     topk_weights: torch.Tensor
     config: TestConfig
@@ -128,12 +127,12 @@ def make_modular_kernel(
     dp_size: int,
     num_experts: int,
     num_local_experts: int,
-    q_dtype: Optional[torch.dtype],
+    q_dtype: torch.dtype | None,
     use_fp8_dispatch: bool,
     quant_config: FusedMoEQuantConfig,
 ) -> FusedMoEModularKernel:
-    ht_args: Optional[DeepEPHTArgs] = None
-    ll_args: Optional[DeepEPLLArgs] = None
+    ht_args: DeepEPHTArgs | None = None
+    ll_args: DeepEPLLArgs | None = None

     if low_latency_mode:
         ll_args = DeepEPLLArgs(
@@ -148,16 +147,14 @@ def make_modular_kernel(
         )
         ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)

-    a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
-        make_deepep_a2a(
-            pg=pg,
-            pgi=pgi,
-            dp_size=dp_size,
-            q_dtype=q_dtype,
-            block_shape=None,
-            deepep_ht_args=ht_args,
-            deepep_ll_args=ll_args,
-        )
-    )
+    a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a(
+        pg=pg,
+        pgi=pgi,
+        dp_size=dp_size,
+        q_dtype=q_dtype,
+        block_shape=None,
+        deepep_ht_args=ht_args,
+        deepep_ll_args=ll_args,
+    )

     num_dispatchers = pgi.world_size // dp_size
@@ -184,8 +181,8 @@ def deep_ep_moe_impl(
     test_tensors: TestTensors,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
+    w1_scale: torch.Tensor | None,
+    w2_scale: torch.Tensor | None,
     num_experts: int,
     use_fp8_dispatch: bool,
     per_act_token_quant: bool,
@@ -281,8 +278,8 @@ def torch_moe_impl(
     test_tensors: TestTensors,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
+    w1_scale: torch.Tensor | None,
+    w2_scale: torch.Tensor | None,
     using_fp8_dispatch: bool,
     per_act_token_quant: bool,
 ):
@@ -340,8 +337,8 @@ def _deep_ep_moe(
     config: TestConfig,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
+    w1_scale: torch.Tensor | None,
+    w2_scale: torch.Tensor | None,
     use_fp8_dispatch: bool,
     per_act_token_quant: bool,
 ):

@@ -5,7 +5,7 @@ import copy
 import textwrap
 import traceback
 from itertools import product
-from typing import Any, Optional
+from typing import Any

 import pytest
 import torch
@@ -245,10 +245,10 @@ def test_modular_kernel_combinations_multigpu(
     n: int,
     e: int,
     dtype: torch.dtype,
-    quant_config: Optional[TestMoEQuantConfig],
+    quant_config: TestMoEQuantConfig | None,
     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
     fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
-    chunk_size: Optional[int],
+    chunk_size: int | None,
     world_size: int,
     pytestconfig,
 ):
@@ -287,10 +287,10 @@ def test_modular_kernel_combinations_singlegpu(
     n: int,
     e: int,
     dtype: torch.dtype,
-    quant_config: Optional[TestMoEQuantConfig],
+    quant_config: TestMoEQuantConfig | None,
     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
     fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
-    chunk_size: Optional[int],
+    chunk_size: int | None,
     world_size: int,
     pytestconfig,
 ):

@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_moe.py`.
 """

 import functools
-from typing import Callable, Optional, Union
+from collections.abc import Callable

 import pytest
 import torch
@@ -80,7 +80,7 @@ vllm_config.scheduler_config.max_model_len = 8192


 def run_moe_test(
-    baseline: Union[Callable, torch.Tensor],
+    baseline: Callable | torch.Tensor,
     moe_fn: Callable,
     a: torch.Tensor,
     w1: torch.Tensor,
@@ -88,7 +88,7 @@ def run_moe_test(
     score: torch.Tensor,
     topk: int,
     global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
     padding: bool = False,
     use_compile: bool = False,
     use_cudagraph: bool = False,
@@ -212,7 +212,7 @@ def test_fused_moe(
         score: torch.Tensor,
         topk: int,
         global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
+        expert_map: torch.Tensor | None = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
         return m_fused_moe_fn(

@@ -5,8 +5,6 @@
 Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
 """

-from typing import Optional
-
 import pytest
 import torch

@@ -94,7 +92,7 @@ def torch_moe_align_block_size(
     topk_ids: torch.Tensor,
     block_size: int,
     num_experts: int,
-    expert_map: Optional[torch.Tensor] = None,
+    expert_map: torch.Tensor | None = None,
     pad_sorted_ids: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """

@@ -5,8 +5,6 @@
 Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
 """

-from typing import Optional
-
 import numpy as np
 import pytest
 import torch
@@ -34,8 +32,8 @@ def torch_permute(
     n_expert: int,
     n_local_expert: int,
     start_expert: int,
-    expert_map: Optional[torch.Tensor] = None,
-    align_block_size: Optional[int] = None,
+    expert_map: torch.Tensor | None = None,
+    align_block_size: int | None = None,
     fill_invalid_expert: int = -1,
 ) -> list[torch.Tensor]:
     n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
@@ -210,7 +208,7 @@ def test_moe_permute_unpermute(
     n_expert: int,
     ep_size: int,
     dtype: torch.dtype,
-    align_block_size: Optional[int],
+    align_block_size: int | None,
 ):
     if not moe_permute_unpermute_supported():
         pytest.skip("moe_permute_unpermute is not supported on this platform.")

@@ -4,7 +4,6 @@
 import importlib.metadata
 from dataclasses import dataclass
 from importlib.util import find_spec
-from typing import Optional

 import pytest
 import torch
@@ -103,7 +102,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
     assert output


-def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
+def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
     # Note we add an extra bias of 1 to the linear layer
     x_glu, x_linear = torch.chunk(x, 2, dim=-1)
     if limit is not None:
@@ -510,7 +509,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
     hidden_size: int,
     alpha: float,
     beta: float,
-    limit: Optional[float],
+    limit: float | None,
     act_type: str,
     transpose_optimized: bool,
 ):
@@ -660,7 +659,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
     hidden_size: int,
     alpha: float,
     beta: float,
-    limit: Optional[float],
+    limit: float | None,
 ):
     torch.manual_seed(42)
     device = "cuda:0"
@@ -811,9 +810,9 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
     num_tokens: int,
     intermediate_size: int,
     hidden_size: int,
-    alpha: Optional[float],
-    beta: Optional[float],
-    limit: Optional[float],
+    alpha: float | None,
+    beta: float | None,
+    limit: float | None,
 ):
     torch.manual_seed(42)
     device = "cuda:0"

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional

 import pytest
 import torch
@@ -73,7 +72,7 @@ def pplx_cutlass_moe(
     out_dtype,
     per_act_token: bool,
     per_out_ch: bool,
-    group_name: Optional[str],
+    group_name: str | None,
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize,

@@ -9,7 +9,7 @@ import copy
 import itertools
 import textwrap
 import traceback
-from typing import Callable, Optional, Union
+from collections.abc import Callable

 import pytest
 import torch
@@ -89,7 +89,7 @@ def torch_prepare(
     a: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int,
-    max_num_tokens: Optional[int] = None,
+    max_num_tokens: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]
@@ -214,10 +214,10 @@ def create_pplx_prepare_finalize(
     dp_size: int,
     world_size: int,
     in_dtype: torch.dtype,
-    quant_dtype: Optional[torch.dtype],
-    block_shape: Optional[list[int]],
+    quant_dtype: torch.dtype | None,
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
-    group_name: Optional[str],
+    group_name: str | None,
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize,
@@ -274,18 +274,14 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
     return t[(r * chunk) : (r + 1) * chunk]


-def maybe_chunk_by_rank(
-    t: Optional[torch.Tensor], r: int, w: int
-) -> Optional[torch.Tensor]:
+def maybe_chunk_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
     if t is not None:
         return chunk_by_rank(t, r, w)
     else:
         return t


-def chunk_scales_by_rank(
-    t: Optional[torch.Tensor], r: int, w: int
-) -> Optional[torch.Tensor]:
+def chunk_scales_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
     if t is not None and t.numel() > 1:
         chunk = rank_chunk(t.shape[0], r, w)
         return t[(r * chunk) : (r + 1) * chunk]
@@ -293,9 +289,7 @@ def chunk_scales_by_rank(
     return t


-def chunk_scales(
-    t: Optional[torch.Tensor], start: int, end: int
-) -> Optional[torch.Tensor]:
+def chunk_scales(t: torch.Tensor | None, start: int, end: int) -> torch.Tensor | None:
     if t is not None and t.numel() > 1:
         return t[start:end]
     else:
@@ -313,10 +307,10 @@ def pplx_prepare_finalize(
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int,
-    quant_dtype: Optional[torch.dtype],
-    block_shape: Optional[list[int]],
+    quant_dtype: torch.dtype | None,
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
-    group_name: Optional[str],
+    group_name: str | None,
 ) -> torch.Tensor:
     assert torch.cuda.current_device() == pgi.local_rank

@@ -409,8 +403,8 @@ def _pplx_prepare_finalize(
     score: torch.Tensor,
     topk: torch.Tensor,
     num_experts: int,
-    quant_dtype: Optional[torch.dtype],
-    block_shape: Optional[list[int]],
+    quant_dtype: torch.dtype | None,
+    block_shape: list[int] | None,
     per_act_token_quant: bool,
     use_internode: bool,
 ):
@@ -479,7 +473,7 @@ def test_pplx_prepare_finalize_slow(
     dtype: torch.dtype,
     world_dp_size: tuple[int, int],
     per_act_token_quant: bool,
-    block_shape: Optional[list[int]],
+    block_shape: list[int] | None,
     use_internode: bool,
 ):
     if dtype == torch.float8_e4m3fn:
@@ -521,7 +515,7 @@ def test_pplx_prepare_finalize_slow(


 def pplx_moe(
-    group_name: Optional[str],
+    group_name: str | None,
     rank: int,
     world_size: int,
     dp_size: int,
@@ -530,17 +524,17 @@ def pplx_moe(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    a1_scale: torch.Tensor | None = None,
+    a2_scale: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant=False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
     use_compile: bool = False,
     use_cudagraphs: bool = True,
-    shared_experts: Optional[torch.nn.Module] = None,
-) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    shared_experts: torch.nn.Module | None = None,
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
     num_tokens, hidden_dim = a.shape
     num_experts = w1.shape[0]
     topk = topk_ids.shape[1]
@@ -657,13 +651,13 @@ def _pplx_moe(
     score: torch.Tensor,
     topk: int,
     num_experts: int,
-    w1_s: Optional[torch.Tensor] = None,
-    w2_s: Optional[torch.Tensor] = None,
-    quant_dtype: Optional[torch.dtype] = None,
+    w1_s: torch.Tensor | None = None,
+    w2_s: torch.Tensor | None = None,
+    quant_dtype: torch.dtype | None = None,
     per_act_token_quant: bool = False,
-    block_shape: Optional[list[int]] = None,
+    block_shape: list[int] | None = None,
     use_internode: bool = False,
-    shared_experts: Optional[torch.nn.Module] = None,
+    shared_experts: torch.nn.Module | None = None,
 ):
     try:
         if use_internode:
@@ -812,7 +806,7 @@ def test_pplx_moe_slow(
     dtype: torch.dtype,
     world_dp_size: tuple[int, int],
     per_act_token_quant: bool,
-    block_shape: Optional[list[int]],
+    block_shape: list[int] | None,
     use_internode: bool,
 ):
     current_platform.seed_everything(7)

@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -27,13 +26,13 @@ def triton_moe(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
@ -54,13 +53,13 @@ def batched_moe(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
@ -94,13 +93,13 @@ def naive_batched_moe(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
@ -129,8 +128,8 @@ def naive_batched_moe(
|
||||
|
||||
|
||||
def chunk_scales(
|
||||
scales: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
@ -144,10 +143,10 @@ def make_quantized_test_activations(
|
||||
m: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_dtype: torch.dtype | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
|
||||
a_q = a
|
||||
a_scale = None
|
||||
@ -172,11 +171,11 @@ def make_quantized_test_activations(
|
||||
|
||||
def moe_quantize_weights(
|
||||
w: torch.Tensor,
|
||||
w_s: Optional[torch.Tensor],
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
w_s: torch.Tensor | None,
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
block_shape: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
assert (
|
||||
quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8
|
||||
@ -220,10 +219,10 @@ def make_test_weight(
|
||||
rows: int,
|
||||
cols: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None

@ -262,12 +261,12 @@ def make_test_weights(
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
quant_dtype: torch.dtype | str | None = None,
block_shape: list[int] | None = None,
per_out_ch_quant: bool = False,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
return (
make_test_weight(
@ -295,9 +294,9 @@ def make_test_quant_config(
n: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None] = None,
quant_dtype: torch.dtype | str | None = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
@ -310,8 +309,8 @@ def make_test_quant_config(
)

# Hacky/trivial scales for nvfp4.
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
a1_gscale: torch.Tensor | None = None
a2_gscale: torch.Tensor | None = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
@ -348,9 +347,9 @@ def fused_moe(
score: torch.Tensor,
topk: int,
renormalize: bool = False,
quant_config: Optional[FusedMoEQuantConfig] = None,
quant_config: FusedMoEQuantConfig | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(
hidden_states, score.float(), topk, renormalize
@ -378,7 +377,7 @@ class BaselineMM(torch.nn.Module):
self.b = b.to(dtype=torch.float32)
self.out_dtype = out_dtype

def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None


@ -422,8 +421,8 @@ class RealMLP(torch.nn.Module):
quant_config=None,
reduce_results: bool = True,
prefix: str = "",
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
w1_s: torch.Tensor | None = None,
w2_s: torch.Tensor | None = None,
) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@ -481,7 +480,7 @@ def make_shared_experts(
N: int,
K: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
from vllm.model_executor.layers.quantization.fp8 import Fp8Config

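The hunks above are representative of the whole commit: every annotation is migrated from the typing aliases to PEP 604 union syntax, where `Optional[X]` becomes `X | None` and `Union[X, Y]` becomes `X | Y`. As a minimal sketch of why this is a safe, behavior-preserving rewrite (the function below is illustrative, not from the diff):

# Python 3.10+: PEP 604 unions are valid both in annotations and at runtime.
from typing import Optional, Union

def scale_or_default(scale: float | None = None) -> float:
    # `float | None` carries exactly the same meaning as `Optional[float]`.
    return 1.0 if scale is None else scale

# The old and new spellings compare equal, so type checkers treat them alike.
assert (float | None) == Optional[float]
assert (int | str) == Union[int, str]
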
@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import torch

@ -15,13 +14,13 @@ ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
return torch.as_tensor(x, dtype=torch.float32, device="cuda")


def ref_dynamic_per_token_quant(
x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
@ -76,8 +75,8 @@ def ref_dynamic_per_token_quant(
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(
x: torch.tensor,
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
@ -250,10 +249,10 @@ def per_block_cast_to_int8(

def dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
f32 = torch.float32
@ -267,10 +266,10 @@ def dequant(

def batched_dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
assert t.shape[0] == scale.shape[0]
@ -289,9 +288,9 @@ def native_batched_masked_quant_matmul(
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
A_scale: torch.Tensor | None = None,
B_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()

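Note that the `as_float32_tensor` and `ref_dynamic_per_token_quant` hunks above fix more than the union spelling: the old annotations used `torch.tensor` (the factory function) where the class `torch.Tensor` was meant. A small illustrative sketch of the distinction (the helper name below is hypothetical):

import torch

# torch.tensor is a constructor function; torch.Tensor is the tensor type.
# Only the type is meaningful in an annotation.
def as_float32(x: float | torch.Tensor) -> torch.Tensor:
    return torch.as_tensor(x, dtype=torch.float32)

assert callable(torch.tensor)                     # factory function
assert isinstance(as_float32(2.0), torch.Tensor)  # the actual class
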
@ -6,7 +6,6 @@ Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
"""

from dataclasses import dataclass
from typing import Optional

import pytest
import torch
@ -60,10 +59,10 @@ SCHEDULES = [
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None


@dataclass
@ -80,7 +79,7 @@ class Tensors:
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
*(
@ -116,8 +115,8 @@ def cutlass_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
@ -143,7 +142,7 @@ def cutlass_quantize_and_pack(


def create_test_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> Tensors:
m, n, k = shape

@ -185,8 +184,8 @@ def create_test_tensors(
def mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
group_size: int | None = None,
schedule: str | None = None,
):
# CUTLASS upstream uses fp8 with fastaccum as reference
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406

@ -7,7 +7,6 @@ Run `pytest tests/kernels/quantization/test_machete_mm.py`.

import math
from dataclasses import dataclass, fields
from typing import Optional

import pytest
import torch
@ -50,11 +49,11 @@ MNK_SHAPES = [
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
group_zero_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
group_zero_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None


@dataclass
@ -63,10 +62,10 @@ class Tensors:
a_ref: torch.Tensor
a: torch.Tensor
w_q: torch.Tensor
w_g_s: Optional[torch.Tensor]
w_g_zp: Optional[torch.Tensor]
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]
w_g_s: torch.Tensor | None
w_g_zp: torch.Tensor | None
w_ch_s: torch.Tensor | None
w_tok_s: torch.Tensor | None


# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
@ -74,7 +73,7 @@ class Tensors:
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
# GPTQ style
@ -139,11 +138,11 @@ def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")


def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))


def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0


@ -151,8 +150,8 @@ def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
@ -178,8 +177,8 @@ def machete_quantize_and_pack(
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
group_size: int | None,
subset_stride_factor: int | None = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
@ -243,8 +242,8 @@ def create_test_tensors(
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
group_size: int | None = None,
schedule: str | None = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
@ -294,7 +293,7 @@ def machete_mm_test_helper(
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
@ -323,7 +322,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else:

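The `list[Optional[int]]` -> `list[int | None]` hunks above hold at runtime as well as for type checkers: on Python 3.10+ the `|` operator on types builds a `types.UnionType` that standard introspection and `isinstance` both understand. A minimal sketch (the variable name is illustrative):

import types
import typing

group_size: int | None = None  # same contract as Optional[int]

assert isinstance(int | None, types.UnionType)
assert typing.get_args(int | None) == (int, type(None))
assert isinstance(group_size, int | None)  # unions work directly in isinstance
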
@ -6,7 +6,6 @@ Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
"""

import importlib
from typing import Optional

import pytest
import torch
@ -27,7 +26,7 @@ def torch_scaled_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out

@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""

from typing import Optional

import pytest
import torch

@ -38,8 +36,8 @@ def ref_int8_scaled_mm(
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
azp: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
azp: torch.Tensor | None,
bias: torch.Tensor | None,
output_type: torch.dtype,
):
if azp is not None:

@ -7,7 +7,7 @@ import random
import unittest
from collections.abc import Sequence
from numbers import Number
from typing import Any, NamedTuple, Optional, Union
from typing import Any, NamedTuple

import pytest
import torch
@ -96,10 +96,10 @@ class PackedQKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_start_loc_list: Optional[list[int]]
kv_start_loc_list: Optional[list[int]]
q_seq_lens: Optional[list[int]]
kv_seq_lens: Optional[list[int]]
q_start_loc_list: list[int] | None
kv_start_loc_list: list[int] | None
q_seq_lens: list[int] | None
kv_seq_lens: list[int] | None


class PackedQKVO(NamedTuple):
@ -115,7 +115,7 @@ class PackedQKVO(NamedTuple):
x head_size) known-correct attention output
"""

packed_qkv: Optional[PackedQKVInputs]
packed_qkv: PackedQKVInputs | None
ideal_output: torch.Tensor


@ -149,12 +149,12 @@ class PhaseTestParameters(NamedTuple):
"""

packed_qkvo: PackedQKVO
kv_mmap: Optional[KVMemoryMap]
kv_mmap: KVMemoryMap | None


def maybe_make_int_tensor(
_list: Optional[list[int]],
device: Union[torch.device, str],
_list: list[int] | None,
device: torch.device | str,
) -> torch.Tensor:
"""
Convert Python int list to a 1D int torch.Tensor on `device`
@ -170,8 +170,8 @@ def maybe_make_int_tensor(


def maybe_make_long_tensor(
_list: Optional[list[int]],
device: Union[torch.device, str],
_list: list[int] | None,
device: torch.device | str,
) -> torch.Tensor:
"""
Convert Python int list to a 1D long torch.Tensor on `device`
@ -186,7 +186,7 @@ def maybe_make_long_tensor(
)


def maybe_max(_list: Optional[list]) -> Optional[Number]:
def maybe_max(_list: list | None) -> Number | None:
"""
Returns:

@ -241,9 +241,9 @@ def ref_masked_attention(
key: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[list] = None,
kv_seq_lens: Optional[list] = None,
custom_mask: torch.Tensor | None = None,
q_seq_lens: list | None = None,
kv_seq_lens: list | None = None,
) -> torch.Tensor:
"""
"Golden" masked attention reference. Supports two types of masking:
@ -302,11 +302,11 @@ def ref_masked_attention(
def make_qkv(
batch_size: int,
max_q_seq_len: int,
max_kv_seq_len: Optional[int],
max_kv_seq_len: int | None,
num_heads: int,
head_size: int,
device: Union[torch.device, str],
force_kv_seq_lens: Optional[list[int]] = None,
device: torch.device | str,
force_kv_seq_lens: list[int] | None = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
@ -436,7 +436,7 @@ def make_qkv(


def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str]
unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
) -> tuple[torch.Tensor, list[int]]:
"""
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
@ -470,7 +470,7 @@ def pack_tensor(
return packed_tensor, start_loc_list


def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs:
def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
"""
Individually pack each of Q, K and V, each with dimensions batch_size x
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
@ -594,19 +594,19 @@ def make_alibi_bias(


def _make_metadata_tensors(
seq_lens: Optional[list[int]],
context_lens: Optional[list[int]],
encoder_seq_lens: Optional[list[int]],
device: Union[torch.device, str],
seq_lens: list[int] | None,
context_lens: list[int] | None,
encoder_seq_lens: list[int] | None,
device: torch.device | str,
) -> tuple[
torch.Tensor,
torch.Tensor,
Any,
Any,
Optional[torch.Tensor],
torch.Tensor | None,
torch.Tensor,
torch.Tensor,
Optional[int],
int | None,
]:
"""
Build scalar & tensor values required to build attention metadata structure.
@ -678,7 +678,7 @@ def make_kv_cache(
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str],
device: torch.device | str,
backend: str,
default_val: float = 0.0,
) -> torch.Tensor:
@ -726,18 +726,18 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
return (num_tokens + block_size) // block_size


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
def make_empty_slot_mapping_tensor(device: torch.device | str):
return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
def make_empty_block_tables_tensor(device: torch.device | str):
return torch.tensor([], device=device)


def split_slot_mapping(
slot_mapping_list: torch.Tensor,
seq_lens: list[int],
device: Union[torch.device, str],
device: torch.device | str,
):
"""
Split a slot mapping into valid prefill- and decode-phase slot mappings.
@ -799,7 +799,7 @@ def split_slot_mapping(
def make_block_tables_slot_mapping(
block_size: int,
seq_lens: list[int],
device: Union[torch.device, str],
device: torch.device | str,
block_base_addr: int = 0,
) -> tuple[torch.Tensor, list[int], int]:
"""
@ -880,11 +880,11 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[list[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
cross_test_params: Optional[PhaseTestParameters] = None,
seq_lens: list[int] | None,
decoder_test_params: PhaseTestParameters | None,
device: torch.device | str,
encoder_test_params: PhaseTestParameters | None = None,
cross_test_params: PhaseTestParameters | None = None,
) -> AttentionMetadata:
"""
Construct fake attention metadata for a given test phase
@ -1142,16 +1142,16 @@ def torch_experts(
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
b_bias1: torch.Tensor | None = None,
b_bias2: torch.Tensor | None = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
assert (
@ -1261,10 +1261,10 @@ def torch_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
b_bias1: torch.Tensor | None = None,
b_bias2: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
@ -1298,15 +1298,13 @@ def torch_moe_single(a, w, score, topk):
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
op: Union[
torch._ops.OpOverload,
torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef,
],
op: torch._ops.OpOverload
| torch._ops.OpOverloadPacket
| torch._library.custom_ops.CustomOpDef,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True,
) -> dict[str, str]:
@ -1338,7 +1336,7 @@ def baseline_scaled_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match
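The `opcheck` hunk above also shows the formatting convention this migration uses when a union is too wide for one line: the bracketed `Union[...]` block becomes a leading-`|` continuation inside the parameter list. A hedged sketch of that layout with placeholder types (not vLLM APIs):

# Inside parentheses, an annotation may wrap with a leading pipe per line.
def describe(
    value: int
    | float
    | complex,
) -> str:
    # Placeholder body; only the annotation layout matters here.
    return type(value).__name__

print(describe(1.5))  # -> "float"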