Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-12 17:51:31 +01:00; committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
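
For context, a minimal before/after sketch of the annotation style this commit adopts (hypothetical helpers, not taken from the diff). PEP 604 unions are evaluated natively on Python 3.10+; older interpreters only accept them as string annotations (for example via "from __future__ import annotations"), not in runtime expressions.

# Illustrative sketch only; these helpers are hypothetical, not part of the change.
import torch

def as_float32(x: float | torch.Tensor) -> torch.Tensor:
    # Previously written as: Union[float, torch.Tensor]
    return torch.as_tensor(x, dtype=torch.float32)

def masked_sum(x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    # Previously written as: Optional[torch.Tensor] = None
    return x.sum() if mask is None else (x * mask).sum()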

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -27,8 +26,8 @@ def ref_paged_attn(
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@ -94,12 +93,12 @@ def test_varlen_with_paged_kv(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
num_blocks: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional
import pytest
import torch
@ -50,7 +49,7 @@ def ref_masked_attention(
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
@ -69,7 +68,7 @@ def ref_single_query_cached_kv_attention(
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
alibi_slopes: torch.Tensor | None,
) -> None:
num_query_heads = query.shape[1]
num_kv_heads = value_cache.shape[1]
@ -415,7 +414,7 @@ def ref_multi_query_kv_attention(
key: torch.Tensor,
value: torch.Tensor,
scale: float,
alibi_bias: Optional[list[torch.Tensor]],
alibi_bias: list[torch.Tensor] | None,
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -85,7 +84,7 @@ def test_cascade(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
num_blocks: int,
fa_version: int,
) -> None:

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import random
from typing import Optional
import pytest
import torch
@ -17,7 +16,7 @@ def cal_diff(
y: torch.Tensor,
name: str,
use_fp8: bool = False,
diff_threshold: Optional[float] = None,
diff_threshold: float | None = None,
) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -34,8 +33,8 @@ def ref_paged_attn(
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@ -103,11 +102,11 @@ def test_flash_attn_with_paged_kv(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
num_blocks: int,
sliding_window: Optional[int],
sliding_window: int | None,
fa_version: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
@ -221,13 +220,13 @@ def test_varlen_with_paged_kv(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
num_blocks: int,
fa_version: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import flashinfer
import pytest
@ -26,8 +25,8 @@ def ref_paged_attn(
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@ -90,8 +89,8 @@ def test_flashinfer_decode_with_paged_kv(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
sliding_window: Optional[int],
soft_cap: float | None,
sliding_window: int | None,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@ -185,8 +184,8 @@ def test_flashinfer_prefill_with_paged_kv(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
sliding_window: Optional[int],
soft_cap: float | None,
sliding_window: int | None,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@ -288,7 +287,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
) -> None:
pytest.skip("TODO: fix the accuracy issue")
torch.set_default_device("cuda")
@ -398,7 +397,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import flashinfer
import pytest
@ -68,9 +67,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
@ -78,7 +75,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
soft_cap: float | None,
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")
@ -267,9 +264,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
@ -277,7 +272,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
soft_cap: float | None,
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -20,7 +19,7 @@ def merge_attn_states_torch(
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
):
p_lse = prefix_lse
s_lse = suffix_lse

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -32,8 +31,8 @@ def ref_paged_attn(
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@ -98,12 +97,12 @@ def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
soft_cap: float | None,
num_blocks: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
) -> None:
torch.set_default_device("cuda")

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import pytest
import torch
@ -31,13 +30,13 @@ EPS = 1e-6
## Helpers
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_rms_norm(
rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor]
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor | None
) -> tuple[torch.Tensor, torch.Tensor | None]:
if residual is not None:
residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual)
@ -51,9 +50,9 @@ def ref_dynamic_per_token_quant(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
@ -76,9 +75,9 @@ def ref_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ref_dynamic_per_token_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub
)
@ -88,9 +87,9 @@ def ops_dynamic_per_token_quant(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(
@ -103,9 +102,9 @@ def ops_impl(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from itertools import product
from typing import Callable, Optional
import pytest
import torch
@ -68,7 +68,7 @@ def test_rotary_embedding(
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
rotary_dim: int | None,
dtype: torch.dtype,
seed: int,
device: str,

View File

@ -4,8 +4,6 @@
Tests for miscellaneous utilities
"""
from typing import Optional
import pytest
import torch
@ -17,7 +15,7 @@ def rotary_embedding_opcheck(
rot,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
key: torch.Tensor | None = None,
):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -19,11 +18,11 @@ from vllm.platforms import current_platform
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
initial_states: torch.Tensor | None = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
final_states_out: torch.Tensor | None = None,
activation: str | None = "silu",
):
"""
x: (batch, dim, seqlen)
@ -117,12 +116,12 @@ def causal_conv1d_update_ref(
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
bias: torch.Tensor | None = None,
cu_seq_len: torch.Tensor | None = None,
cache_indices: torch.Tensor | None = None,
has_initial_state: torch.Tensor | None = None,
conv_states: torch.Tensor | None = None,
activation: str | None = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any
import torch
@ -35,7 +35,7 @@ from .mk_objects import (
from .parallel_utils import ProcessGroupInfo
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
if t is None:
return f"{name} : None"
else:
@ -44,21 +44,21 @@ def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
@dataclass
class Config:
Ms: Union[list[int], int]
Ms: list[int] | int
K: int
N: int
E: int
topks: Union[list[int], int]
topks: list[int] | int
dtype: torch.dtype
quant_config: Optional[TestMoEQuantConfig]
quant_config: TestMoEQuantConfig | None
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
fused_moe_chunk_size: Optional[int]
fused_moe_chunk_size: int | None
world_size: int
torch_trace_dir_path: Optional[str] = None
torch_trace_dir_path: str | None = None
def __post_init__(self):
if self.quant_config is None:
@ -93,7 +93,7 @@ class Config:
return self.Ms
@property
def quant_dtype(self) -> Union[torch.dtype, str, None]:
def quant_dtype(self) -> torch.dtype | str | None:
assert self.quant_config is not None
return self.quant_config.quant_dtype
@ -112,7 +112,7 @@ class Config:
return self.quant_config.per_out_ch_quant
@property
def quant_block_shape(self) -> Optional[list[int]]:
def quant_block_shape(self) -> list[int] | None:
assert self.quant_config is not None
return self.quant_config.block_shape
@ -209,7 +209,7 @@ class Config:
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend
def is_valid(self) -> tuple[bool, Optional[str]]:
def is_valid(self) -> tuple[bool, str | None]:
# Check prepare-finalize and fused-experts compatibility
if self.is_batched_prepare_finalize():
if not self.is_batched_fused_experts():
@ -280,10 +280,10 @@ class Config:
class WeightTensors:
w1: torch.Tensor
w2: torch.Tensor
w1_scale: Optional[torch.Tensor]
w2_scale: Optional[torch.Tensor]
w1_gs: Optional[torch.Tensor] = None
w2_gs: Optional[torch.Tensor] = None
w1_scale: torch.Tensor | None
w2_scale: torch.Tensor | None
w1_gs: torch.Tensor | None = None
w2_gs: torch.Tensor | None = None
def describe(self):
s = ""
@ -351,11 +351,11 @@ class WeightTensors:
@dataclass
class RankTensors:
hidden_states: torch.Tensor
hidden_states_scale: Optional[torch.Tensor]
hidden_states_scale: torch.Tensor | None
topk_weights: torch.Tensor
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor]
expert_map: torch.Tensor | None
def describe(self):
s = ""
@ -370,7 +370,7 @@ class RankTensors:
@staticmethod
def make_hidden_states(
config: Config,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Return hidden_states
"""

View File

@ -4,7 +4,6 @@
import copy
from enum import Enum
from itertools import product
from typing import Optional
import torch
from tqdm import tqdm
@ -82,7 +81,7 @@ def make_feature_matrix(csv_file_path: str):
import pandas as pd
def add_to_results(
config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
config: Config, success: Result, results_df: pd.DataFrame | None = None
):
config_dict = asdict(config)
config_dict["prepare_finalize_type"] = config_dict[
@ -121,7 +120,7 @@ def make_feature_matrix(csv_file_path: str):
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
)
results_df: Optional[pd.DataFrame] = None
results_df: pd.DataFrame | None = None
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
combinations
):

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch
@ -43,25 +42,25 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@dataclass
class TestMoEQuantConfig:
quant_dtype: Union[torch.dtype, str, None]
quant_dtype: torch.dtype | str | None
per_out_ch_quant: bool
per_act_token_quant: bool
block_shape: Optional[list[int]]
block_shape: list[int] | None
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool
backend: Optional[str]
backend: str | None
supports_apply_weight_on_input: bool = True
@dataclass
class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
@ -78,7 +77,7 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
common_float_types: list[torch.dtype | str] = [
torch.float8_e4m3fn,
torch.bfloat16,
torch.float16,
@ -92,9 +91,9 @@ fp8_types = [torch.float8_e4m3fn]
def register_prepare_and_finalize(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool,
backend: Optional[str],
backend: str | None,
force_multigpu: bool = False,
supports_apply_weight_on_input: bool = True,
):
@ -121,7 +120,7 @@ def register_prepare_and_finalize(
def register_experts(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
@ -340,7 +339,7 @@ if cutlass_fp4_supported():
supports_expert_map=False,
)
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
None,
# per-channel / per-column weights and per-tensor activations
TestMoEQuantConfig(
@ -395,7 +394,7 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
backend: str | None,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:

View File

@ -3,11 +3,12 @@
import dataclasses
import os
import traceback
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Concatenate
import torch
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import init_distributed_environment, initialize_model_parallel
@ -58,9 +59,9 @@ def _worker_parallel_launch(
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
vllm_config: Optional[VllmConfig],
env_dict: Optional[dict],
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
vllm_config: VllmConfig | None,
env_dict: dict | None,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
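
Alongside the union rewrite, this hunk moves Callable to collections.abc and imports Concatenate from typing rather than typing_extensions (both are in the standard library on Python 3.10+), keeping ParamSpec from typing_extensions. A minimal sketch of the resulting import and usage pattern (hypothetical launcher, not the actual vLLM worker signature):

# Sketch only; the callback and launcher below are hypothetical.
from collections.abc import Callable
from typing import Concatenate

from typing_extensions import ParamSpec

P = ParamSpec("P")

def launch(
    worker: Callable[Concatenate[int, P], None],  # leading int is a rank
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    worker(0, *args, **kwargs)

# Extra arguments are type-checked against the worker's own parameter list.
launch(lambda rank, name: print(rank, name), "demo")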

View File

@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Callable
from itertools import product
from typing import Any, Callable
from typing import Any
import torch

View File

@ -7,12 +7,13 @@ DeepEP test utilities
import dataclasses
import os
import traceback
from typing import Callable, Optional
from collections.abc import Callable
from typing import Concatenate
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
from vllm.utils import get_open_port, has_deep_ep
@ -126,8 +127,8 @@ def make_deepep_ht_a2a(
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
q_dtype: torch.dtype | None = None,
block_shape: list[int] | None = None,
):
import deep_ep
@ -153,8 +154,8 @@ def make_deepep_ll_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
q_dtype: torch.dtype | None = None,
block_shape: list[int] | None = None,
):
import deep_ep
@ -185,10 +186,10 @@ def make_deepep_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
deepep_ht_args: DeepEPHTArgs | None,
deepep_ll_args: DeepEPLLArgs | None,
q_dtype: torch.dtype | None = None,
block_shape: list[int] | None = None,
):
if deepep_ht_args is not None:
assert deepep_ll_args is None

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
@ -55,7 +54,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@dataclass
class BatchedMMConfig:
in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
quant_dtype: torch.dtype | None
out_dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
@ -115,7 +114,7 @@ def test_batched_mm(
K: int,
N: int,
dtype: torch.dtype,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
per_act_token_quant: bool,
):
current_platform.seed_everything(7)
@ -242,7 +241,7 @@ def test_fused_moe_batched_experts(
topk: int,
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
input_scales: bool,
):
current_platform.seed_everything(7)

View File

@ -5,7 +5,6 @@ Tests compute_expert_num_tokens kernels
"""
import dataclasses
from typing import Optional
import pytest
import torch
@ -16,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
@dataclasses.dataclass
class TestTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor] = None
expert_map: torch.Tensor | None = None
def to_device(self, device: str):
self.topk_ids = self.topk_ids.to(device=device)

View File

@ -3,7 +3,6 @@
import copy
import dataclasses
from math import prod
from typing import Optional
import pytest
import torch
@ -85,16 +84,16 @@ class MOETensors:
@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
# quantized
a_q: Optional[torch.Tensor] = None # a -> a_q
w1_q: Optional[torch.Tensor] = None # w1 -> w1_q
w2_q: Optional[torch.Tensor] = None # w2 -> w2_q
a_scale: Optional[torch.Tensor] = None
w1_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
a_q: torch.Tensor | None = None # a -> a_q
w1_q: torch.Tensor | None = None # w1 -> w1_q
w2_q: torch.Tensor | None = None # w2 -> w2_q
a_scale: torch.Tensor | None = None
w1_scale: torch.Tensor | None = None
w2_scale: torch.Tensor | None = None
# dequantized
a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d
w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
a_d: torch.Tensor | None = None # a -> a_q -> a_d
w1_d: torch.Tensor | None = None # w1 -> w1_q -> w1_d
w2_d: torch.Tensor | None = None # w2 -> w2_q -> w2_d
@staticmethod
def make_moe_tensors_8bit(
@ -209,7 +208,7 @@ def run_8_bit(
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None,
num_local_experts: int | None = None,
) -> torch.Tensor:
assert not any(
[
@ -280,7 +279,7 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
ep_size: Optional[int] = None,
ep_size: int | None = None,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

View File

@ -7,7 +7,6 @@ fp8 block-quantized case.
"""
import dataclasses
from typing import Optional
import pytest
import torch.distributed
@ -92,13 +91,13 @@ class TestConfig:
block_size: list[int]
# configs for testing low-latency kernels
low_latency: bool
use_fp8_dispatch: Optional[bool] = False
use_fp8_dispatch: bool | None = False
@dataclasses.dataclass
class TestTensors:
rank_tokens: torch.Tensor # all ranks make this many tokens
rank_token_scales: Optional[torch.Tensor]
rank_token_scales: torch.Tensor | None
topk: torch.Tensor
topk_weights: torch.Tensor
config: TestConfig
@ -143,7 +142,7 @@ def make_ll_modular_kernel(
max_tokens_per_rank: int,
dp_size: int,
hidden_size: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
@ -179,7 +178,7 @@ def make_ht_modular_kernel(
pgi: ProcessGroupInfo,
dp_size: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
@ -249,8 +248,8 @@ def deepep_deepgemm_moe_impl(
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
) -> torch.Tensor:
test_config = test_tensors.config
num_experts = test_config.num_experts

View File

@ -5,7 +5,6 @@ Test deepep dispatch-combine logic
"""
import dataclasses
from typing import Optional, Union
import pytest
import torch.distributed
@ -90,7 +89,7 @@ class TestConfig:
@dataclasses.dataclass
class TestTensors:
rank_tokens: torch.Tensor # all ranks make this many tokens
rank_token_scales: Optional[torch.Tensor]
rank_token_scales: torch.Tensor | None
topk: torch.Tensor
topk_weights: torch.Tensor
config: TestConfig
@ -128,12 +127,12 @@ def make_modular_kernel(
dp_size: int,
num_experts: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
q_dtype: torch.dtype | None,
use_fp8_dispatch: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
ht_args: DeepEPHTArgs | None = None
ll_args: DeepEPLLArgs | None = None
if low_latency_mode:
ll_args = DeepEPLLArgs(
@ -148,16 +147,14 @@ def make_modular_kernel(
)
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
q_dtype=q_dtype,
block_shape=None,
deepep_ht_args=ht_args,
deepep_ll_args=ll_args,
)
a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
q_dtype=q_dtype,
block_shape=None,
deepep_ht_args=ht_args,
deepep_ll_args=ll_args,
)
num_dispatchers = pgi.world_size // dp_size
@ -184,8 +181,8 @@ def deep_ep_moe_impl(
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
num_experts: int,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
@ -281,8 +278,8 @@ def torch_moe_impl(
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
@ -340,8 +337,8 @@ def _deep_ep_moe(
config: TestConfig,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_scale: torch.Tensor | None,
w2_scale: torch.Tensor | None,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):

View File

@ -5,7 +5,7 @@ import copy
import textwrap
import traceback
from itertools import product
from typing import Any, Optional
from typing import Any
import pytest
import torch
@ -245,10 +245,10 @@ def test_modular_kernel_combinations_multigpu(
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
chunk_size: int | None,
world_size: int,
pytestconfig,
):
@ -287,10 +287,10 @@ def test_modular_kernel_combinations_singlegpu(
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
quant_config: TestMoEQuantConfig | None,
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
chunk_size: int | None,
world_size: int,
pytestconfig,
):

View File

@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union
from collections.abc import Callable
import pytest
import torch
@ -80,7 +80,7 @@ vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test(
baseline: Union[Callable, torch.Tensor],
baseline: Callable | torch.Tensor,
moe_fn: Callable,
a: torch.Tensor,
w1: torch.Tensor,
@ -88,7 +88,7 @@ def run_moe_test(
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
padding: bool = False,
use_compile: bool = False,
use_cudagraph: bool = False,
@ -212,7 +212,7 @@ def test_fused_moe(
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(

View File

@ -5,8 +5,6 @@
Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
"""
from typing import Optional
import pytest
import torch
@ -94,7 +92,7 @@ def torch_moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""

View File

@ -5,8 +5,6 @@
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
"""
from typing import Optional
import numpy as np
import pytest
import torch
@ -34,8 +32,8 @@ def torch_permute(
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
@ -210,7 +208,7 @@ def test_moe_permute_unpermute(
n_expert: int,
ep_size: int,
dtype: torch.dtype,
align_block_size: Optional[int],
align_block_size: int | None,
):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")

View File

@ -4,7 +4,6 @@
import importlib.metadata
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional
import pytest
import torch
@ -103,7 +102,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
assert output
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
@ -510,7 +509,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
limit: float | None,
act_type: str,
transpose_optimized: bool,
):
@ -660,7 +659,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
hidden_size: int,
alpha: float,
beta: float,
limit: Optional[float],
limit: float | None,
):
torch.manual_seed(42)
device = "cuda:0"
@ -811,9 +810,9 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
num_tokens: int,
intermediate_size: int,
hidden_size: int,
alpha: Optional[float],
beta: Optional[float],
limit: Optional[float],
alpha: float | None,
beta: float | None,
limit: float | None,
):
torch.manual_seed(42)
device = "cuda:0"

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
@ -73,7 +72,7 @@ def pplx_cutlass_moe(
out_dtype,
per_act_token: bool,
per_out_ch: bool,
group_name: Optional[str],
group_name: str | None,
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize,

View File

@ -9,7 +9,7 @@ import copy
import itertools
import textwrap
import traceback
from typing import Callable, Optional, Union
from collections.abc import Callable
import pytest
import torch
@ -89,7 +89,7 @@ def torch_prepare(
a: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
max_num_tokens: Optional[int] = None,
max_num_tokens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert topk_ids.dim() == 2
assert topk_ids.shape[0] == a.shape[0]
@ -214,10 +214,10 @@ def create_pplx_prepare_finalize(
dp_size: int,
world_size: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
block_shape: Optional[list[int]],
quant_dtype: torch.dtype | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
group_name: Optional[str],
group_name: str | None,
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize,
@ -274,18 +274,14 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
return t[(r * chunk) : (r + 1) * chunk]
def maybe_chunk_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
def maybe_chunk_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
if t is not None:
return chunk_by_rank(t, r, w)
else:
return t
def chunk_scales_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
def chunk_scales_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
if t is not None and t.numel() > 1:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk) : (r + 1) * chunk]
@ -293,9 +289,7 @@ def chunk_scales_by_rank(
return t
def chunk_scales(
t: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
def chunk_scales(t: torch.Tensor | None, start: int, end: int) -> torch.Tensor | None:
if t is not None and t.numel() > 1:
return t[start:end]
else:
@ -313,10 +307,10 @@ def pplx_prepare_finalize(
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
quant_dtype: Optional[torch.dtype],
block_shape: Optional[list[int]],
quant_dtype: torch.dtype | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
group_name: Optional[str],
group_name: str | None,
) -> torch.Tensor:
assert torch.cuda.current_device() == pgi.local_rank
@ -409,8 +403,8 @@ def _pplx_prepare_finalize(
score: torch.Tensor,
topk: torch.Tensor,
num_experts: int,
quant_dtype: Optional[torch.dtype],
block_shape: Optional[list[int]],
quant_dtype: torch.dtype | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
use_internode: bool,
):
@ -479,7 +473,7 @@ def test_pplx_prepare_finalize_slow(
dtype: torch.dtype,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
use_internode: bool,
):
if dtype == torch.float8_e4m3fn:
@ -521,7 +515,7 @@ def test_pplx_prepare_finalize_slow(
def pplx_moe(
group_name: Optional[str],
group_name: str | None,
rank: int,
world_size: int,
dp_size: int,
@ -530,17 +524,17 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
use_compile: bool = False,
use_cudagraphs: bool = True,
shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
shared_experts: torch.nn.Module | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
topk = topk_ids.shape[1]
@ -657,13 +651,13 @@ def _pplx_moe(
score: torch.Tensor,
topk: int,
num_experts: int,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
w1_s: torch.Tensor | None = None,
w2_s: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
use_internode: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
shared_experts: torch.nn.Module | None = None,
):
try:
if use_internode:
@ -812,7 +806,7 @@ def test_pplx_moe_slow(
dtype: torch.dtype,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
block_shape: list[int] | None,
use_internode: bool,
):
current_platform.seed_everything(7)

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
@ -27,13 +26,13 @@ def triton_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
@ -54,13 +53,13 @@ def batched_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
@ -94,13 +93,13 @@ def naive_batched_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
@ -129,8 +128,8 @@ def naive_batched_moe(
def chunk_scales(
scales: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
if scales is not None:
if scales.numel() == 1:
return scales
@ -144,10 +143,10 @@ def make_quantized_test_activations(
m: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
quant_dtype: torch.dtype | None = None,
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
a_q = a
a_scale = None
@ -172,11 +171,11 @@ def make_quantized_test_activations(
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Union[torch.dtype, str, None],
w_s: torch.Tensor | None,
quant_dtype: torch.dtype | str | None,
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
assert (
quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8
@ -220,10 +219,10 @@ def make_test_weight(
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
quant_dtype: torch.dtype | str | None = None,
block_shape: list[int] | None = None,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
@ -262,12 +261,12 @@ def make_test_weights(
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
quant_dtype: torch.dtype | str | None = None,
block_shape: list[int] | None = None,
per_out_ch_quant: bool = False,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
return (
make_test_weight(
@ -295,9 +294,9 @@ def make_test_quant_config(
n: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None] = None,
quant_dtype: torch.dtype | str | None = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
@ -310,8 +309,8 @@ def make_test_quant_config(
)
# Hacky/trivial scales for nvfp4.
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
a1_gscale: torch.Tensor | None = None
a2_gscale: torch.Tensor | None = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
@ -348,9 +347,9 @@ def fused_moe(
score: torch.Tensor,
topk: int,
renormalize: bool = False,
quant_config: Optional[FusedMoEQuantConfig] = None,
quant_config: FusedMoEQuantConfig | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(
hidden_states, score.float(), topk, renormalize
@ -378,7 +377,7 @@ class BaselineMM(torch.nn.Module):
self.b = b.to(dtype=torch.float32)
self.out_dtype = out_dtype
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
@ -422,8 +421,8 @@ class RealMLP(torch.nn.Module):
quant_config=None,
reduce_results: bool = True,
prefix: str = "",
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
w1_s: torch.Tensor | None = None,
w2_s: torch.Tensor | None = None,
) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@ -481,7 +480,7 @@ def make_shared_experts(
N: int,
K: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
from vllm.model_executor.layers.quantization.fp8 import Fp8Config

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
@ -15,13 +14,13 @@ ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor:
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_dynamic_per_token_quant(
x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
@ -76,8 +75,8 @@ def ref_dynamic_per_token_quant(
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(
x: torch.tensor,
) -> tuple[torch.tensor, torch.tensor]:
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
@ -250,10 +249,10 @@ def per_block_cast_to_int8(
def dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
f32 = torch.float32
@ -267,10 +266,10 @@ def dequant(
def batched_dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
scale: torch.Tensor | None,
block_shape: list[int] | None,
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
out_dtype: torch.dtype | None = torch.float32,
) -> torch.Tensor:
if scale is not None:
assert t.shape[0] == scale.shape[0]
@ -289,9 +288,9 @@ def native_batched_masked_quant_matmul(
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
A_scale: torch.Tensor | None = None,
B_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()

View File

@ -6,7 +6,6 @@ Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
"""
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
@ -60,10 +59,10 @@ SCHEDULES = [
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None
@dataclass
@ -80,7 +79,7 @@ class Tensors:
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
*(
@ -116,8 +115,8 @@ def cutlass_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
@ -143,7 +142,7 @@ def cutlass_quantize_and_pack(
def create_test_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> Tensors:
m, n, k = shape
@ -185,8 +184,8 @@ def create_test_tensors(
def mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
group_size: int | None = None,
schedule: str | None = None,
):
# CUTLASS upstream uses fp8 with fastaccum as reference
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406

View File

@ -7,7 +7,6 @@ Run `pytest tests/kernels/quantization/test_machete_mm.py`.
import math
from dataclasses import dataclass, fields
from typing import Optional
import pytest
import torch
@ -50,11 +49,11 @@ MNK_SHAPES = [
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
group_zero_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
output_type: torch.dtype | None
group_scale_type: torch.dtype | None
group_zero_type: torch.dtype | None
channel_scale_type: torch.dtype | None
token_scale_type: torch.dtype | None
@dataclass
@ -63,10 +62,10 @@ class Tensors:
a_ref: torch.Tensor
a: torch.Tensor
w_q: torch.Tensor
w_g_s: Optional[torch.Tensor]
w_g_zp: Optional[torch.Tensor]
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]
w_g_s: torch.Tensor | None
w_g_zp: torch.Tensor | None
w_ch_s: torch.Tensor | None
w_tok_s: torch.Tensor | None
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
@ -74,7 +73,7 @@ class Tensors:
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
# GPTQ style
@ -139,11 +138,11 @@ def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0
@ -151,8 +150,8 @@ def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
stype: torch.dtype | None,
group_size: int | None,
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
@ -178,8 +177,8 @@ def machete_quantize_and_pack(
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
group_size: int | None,
subset_stride_factor: int | None = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
@ -243,8 +242,8 @@ def create_test_tensors(
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
group_size: int | None = None,
schedule: str | None = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
@ -294,7 +293,7 @@ def machete_mm_test_helper(
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
@ -323,7 +322,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
group_sizes: list[int | None] = []
if types.group_scale_type is None:
group_sizes = [None]
else:

View File

@ -6,7 +6,6 @@ Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
"""
import importlib
from typing import Optional
import pytest
import torch
@ -27,7 +26,7 @@ def torch_scaled_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out

View File

@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
from typing import Optional
import pytest
import torch
@ -38,8 +36,8 @@ def ref_int8_scaled_mm(
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
azp: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
azp: torch.Tensor | None,
bias: torch.Tensor | None,
output_type: torch.dtype,
):
if azp is not None:

View File

@ -7,7 +7,7 @@ import random
import unittest
from collections.abc import Sequence
from numbers import Number
from typing import Any, NamedTuple, Optional, Union
from typing import Any, NamedTuple
import pytest
import torch
@ -96,10 +96,10 @@ class PackedQKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_start_loc_list: Optional[list[int]]
kv_start_loc_list: Optional[list[int]]
q_seq_lens: Optional[list[int]]
kv_seq_lens: Optional[list[int]]
q_start_loc_list: list[int] | None
kv_start_loc_list: list[int] | None
q_seq_lens: list[int] | None
kv_seq_lens: list[int] | None
class PackedQKVO(NamedTuple):
@ -115,7 +115,7 @@ class PackedQKVO(NamedTuple):
x head_size) known-correct attention output
"""
packed_qkv: Optional[PackedQKVInputs]
packed_qkv: PackedQKVInputs | None
ideal_output: torch.Tensor
@ -149,12 +149,12 @@ class PhaseTestParameters(NamedTuple):
"""
packed_qkvo: PackedQKVO
kv_mmap: Optional[KVMemoryMap]
kv_mmap: KVMemoryMap | None
def maybe_make_int_tensor(
_list: Optional[list[int]],
device: Union[torch.device, str],
_list: list[int] | None,
device: torch.device | str,
) -> torch.Tensor:
"""
Convert Python int list to a 1D int torch.Tensor on `device`
@ -170,8 +170,8 @@ def maybe_make_int_tensor(
def maybe_make_long_tensor(
_list: Optional[list[int]],
device: Union[torch.device, str],
_list: list[int] | None,
device: torch.device | str,
) -> torch.Tensor:
"""
Convert Python int list to a 1D long torch.Tensor on `device`
@ -186,7 +186,7 @@ def maybe_make_long_tensor(
)
def maybe_max(_list: Optional[list]) -> Optional[Number]:
def maybe_max(_list: list | None) -> Number | None:
"""
Returns:
@ -241,9 +241,9 @@ def ref_masked_attention(
key: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[list] = None,
kv_seq_lens: Optional[list] = None,
custom_mask: torch.Tensor | None = None,
q_seq_lens: list | None = None,
kv_seq_lens: list | None = None,
) -> torch.Tensor:
"""
"Golden" masked attention reference. Supports two types of masking:
@ -302,11 +302,11 @@ def ref_masked_attention(
def make_qkv(
batch_size: int,
max_q_seq_len: int,
max_kv_seq_len: Optional[int],
max_kv_seq_len: int | None,
num_heads: int,
head_size: int,
device: Union[torch.device, str],
force_kv_seq_lens: Optional[list[int]] = None,
device: torch.device | str,
force_kv_seq_lens: list[int] | None = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
@ -436,7 +436,7 @@ def make_qkv(
def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str]
unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str
) -> tuple[torch.Tensor, list[int]]:
"""
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
@ -470,7 +470,7 @@ def pack_tensor(
return packed_tensor, start_loc_list
def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs:
def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
"""
Individually pack each of Q, K and V, each with dimensions batch_size x
padded_seq_len x num_heads x head_size, into respective number_of_tokens x
@ -594,19 +594,19 @@ def make_alibi_bias(
def _make_metadata_tensors(
seq_lens: Optional[list[int]],
context_lens: Optional[list[int]],
encoder_seq_lens: Optional[list[int]],
device: Union[torch.device, str],
seq_lens: list[int] | None,
context_lens: list[int] | None,
encoder_seq_lens: list[int] | None,
device: torch.device | str,
) -> tuple[
torch.Tensor,
torch.Tensor,
Any,
Any,
Optional[torch.Tensor],
torch.Tensor | None,
torch.Tensor,
torch.Tensor,
Optional[int],
int | None,
]:
"""
Build scalar & tensor values required to build attention metadata structure.
@ -678,7 +678,7 @@ def make_kv_cache(
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str],
device: torch.device | str,
backend: str,
default_val: float = 0.0,
) -> torch.Tensor:
@ -726,18 +726,18 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
return (num_tokens + block_size) // block_size
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
def make_empty_slot_mapping_tensor(device: torch.device | str):
return maybe_make_long_tensor([], device)
def make_empty_block_tables_tensor(device: Union[torch.device, str]):
def make_empty_block_tables_tensor(device: torch.device | str):
return torch.tensor([], device=device)
def split_slot_mapping(
slot_mapping_list: torch.Tensor,
seq_lens: list[int],
device: Union[torch.device, str],
device: torch.device | str,
):
"""
Split a slot mapping into valid prefill- and decode-phase slot mappings.
@ -799,7 +799,7 @@ def split_slot_mapping(
def make_block_tables_slot_mapping(
block_size: int,
seq_lens: list[int],
device: Union[torch.device, str],
device: torch.device | str,
block_base_addr: int = 0,
) -> tuple[torch.Tensor, list[int], int]:
"""
@ -880,11 +880,11 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[list[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
cross_test_params: Optional[PhaseTestParameters] = None,
seq_lens: list[int] | None,
decoder_test_params: PhaseTestParameters | None,
device: torch.device | str,
encoder_test_params: PhaseTestParameters | None = None,
cross_test_params: PhaseTestParameters | None = None,
) -> AttentionMetadata:
"""
Construct fake attention metadata for a given test phase
@ -1142,16 +1142,16 @@ def torch_experts(
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
b_bias1: torch.Tensor | None = None,
b_bias2: torch.Tensor | None = None,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
quant_dtype: torch.dtype | None = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
block_shape: list[int] | None = None,
apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
assert (
@ -1261,10 +1261,10 @@ def torch_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
b_bias1: Optional[torch.Tensor] = None,
b_bias2: Optional[torch.Tensor] = None,
b_bias1: torch.Tensor | None = None,
b_bias2: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
@ -1298,15 +1298,13 @@ def torch_moe_single(a, w, score, topk):
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
op: Union[
torch._ops.OpOverload,
torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef,
],
op: torch._ops.OpOverload
| torch._ops.OpOverloadPacket
| torch._library.custom_ops.CustomOpDef,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: dict[str, Any] | None = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True,
) -> dict[str, str]:
@ -1338,7 +1336,7 @@ def baseline_scaled_mm(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match