Compare commits

20 commits: fix-precom...fp8_ep_dp

| SHA1 |
|---|
| 1236aebf0e |
| 95c40f9b09 |
| a0efd3106c |
| e69879996f |
| 922165cba3 |
| 12ea698498 |
| caca0b718a |
| d86e3f0172 |
| 3ca8322b74 |
| 03b41b6cad |
| cad6447664 |
| c169b05541 |
| 468d16654a |
| 909f234faa |
| f8510587c2 |
| 9cfebf51ba |
| 77f95b99a6 |
| bbe888d033 |
| 25ed6738d4 |
| e568e401da |
@@ -1,18 +1,38 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass
+from typing import Optional

 import pytest
 import torch
 import triton.language as tl

+import vllm._custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedPrepareAndFinalize, BatchedTritonExperts,
     invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.platforms import current_platform
+from vllm.utils import round_up

+NUM_EXPERTS = [8, 64]
+TOP_KS = [1, 2, 6]
+
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+

 @dataclass
 class BatchedMMConfig:
-    dtype: torch.dtype
+    in_dtype: torch.dtype
+    out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
     K: int

@@ -28,17 +48,26 @@ class BatchedMMTensors:

     @staticmethod
     def make_tensors(config: BatchedMMConfig):
+        if config.in_dtype == torch.torch.float8_e4m3fn:
+            config_in_dtype = torch.bfloat16
+        else:
+            config_in_dtype = config.in_dtype
+
         A = torch.randn(
             (config.num_experts, config.max_tokens_per_expert, config.K),
             device="cuda",
-            dtype=config.dtype) / 10
+            dtype=config_in_dtype) / 10
         B = torch.randn((config.num_experts, config.N, config.K),
                         device="cuda",
-                        dtype=config.dtype)
+                        dtype=config_in_dtype)
         C = torch.zeros(
             (config.num_experts, config.max_tokens_per_expert, config.N),
             device="cuda",
-            dtype=config.dtype)
+            dtype=config.out_dtype)
+
+        A = A.to(config.in_dtype)
+        B = B.to(config.in_dtype)

         num_expert_tokens = torch.randint(low=0,
                                           high=config.max_tokens_per_expert,
                                           size=(config.num_experts, ),

@@ -47,16 +76,96 @@ class BatchedMMTensors:
         return BatchedMMTensors(A, B, C, num_expert_tokens)


-def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-             num_expert_tokens: torch.Tensor) -> torch.Tensor:
+def native_w8a8_block_matmul(A: torch.Tensor,
+                             B: torch.Tensor,
+                             As: torch.Tensor,
+                             Bs: torch.Tensor,
+                             block_size,
+                             output_dtype=torch.bfloat16):
+    """This function performs matrix multiplication with block-wise
+    quantization using native torch.
+    It is agnostic to the input data type and can be used for both int8 and
+    fp8 data types.
+
+    It takes two input tensors `A` and `B` (int8) with scales `As` and
+    `Bs` (float32).
+    The output is returned in the specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32).contiguous()
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
+        f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
+    assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
+
+
+def ref_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    num_expert_tokens: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    block_shape: Optional[list[int]],
+) -> torch.Tensor:
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
     num_experts = num_expert_tokens.size(0)

     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
-        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+        if A.dtype == torch.torch.float8_e4m3fn:
+            if False:
+                tmp = native_w8a8_block_matmul(A[e, :, :],
+                                               B[e].transpose(0, 1), A_scale,
+                                               B_scale, block_shape)
+            else:
+                tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1),
+                                            A_scale, B_scale, torch.bfloat16)
+            C[e, :num_tokens, :] = tmp[:num_tokens, :]
+        else:
+            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

     return C

@@ -66,22 +175,45 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype):
-    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+
+    if dtype == torch.torch.float8_e4m3fn:
+        in_dtype = dtype
+        out_dtype = torch.bfloat16
+    else:
+        in_dtype = dtype
+        out_dtype = dtype
+
+    config = BatchedMMConfig(in_dtype, out_dtype, num_experts,
+                             max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)

     test_output = tensors.C
     ref_output = test_output.clone()
+    ref_output2 = test_output.clone()

     compute_tl_dtype = {
         torch.float16: tl.float16,
         torch.bfloat16: tl.bfloat16,
         torch.float32: tl.float32
     }[test_output.dtype]

+    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    block_shape = [16, 16, 32]  # 16 for k if not fp8
+
+    if use_fp8_w8a8:
+        A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
+        B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
+        quant_block_shape = [1, 1]
+    else:
+        A_scale = None
+        B_scale = None
+        quant_block_shape = None
+
     invoke_moe_batched_triton_kernel(
         tensors.A,
         tensors.B,

@@ -89,21 +221,30 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         tensors.num_expert_tokens,
         compute_tl_dtype,
         # Quantization data
-        None,
-        None,
+        A_scale,
+        B_scale,
         None,
         # Quantization schemes
-        False,
+        use_fp8_w8a8,
         False,
         False,
         config={
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 16,
-            "BLOCK_SIZE_K": 16
-        })
+            "BLOCK_SIZE_M": block_shape[0],
+            "BLOCK_SIZE_N": block_shape[1],
+            "BLOCK_SIZE_K": block_shape[2],
+        },
+        block_shape=quant_block_shape,
+    )

-    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
-                          tensors.num_expert_tokens)
+    ref_output = ref_output.to(dtype=out_dtype)
+    ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
+                          tensors.B.to(dtype=out_dtype), ref_output,
+                          tensors.num_expert_tokens, A_scale, B_scale,
+                          block_shape[-2:])
+
+    ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2,
+                           tensors.num_expert_tokens, A_scale, B_scale,
+                           block_shape[-2:])

     rtol, atol = {
         torch.float16: (6e-2, 6e-2),

@@ -111,4 +252,154 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]

-    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
+
+
+def batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    qtype: Optional[torch.dtype] = None,
+    block_shape: Optional[list[int]] = None,
+    per_act_token: bool = False,
+) -> torch.Tensor:
+    max_num_tokens = round_up(a.shape[0], 64)
+    fused_experts = FusedMoEModularKernel(
+        BatchedPrepareAndFinalize(max_num_tokens,
+                                  world_size=1,
+                                  dp_size=1,
+                                  rank=0,
+                                  qtype=qtype,
+                                  block_shape=block_shape,
+                                  per_act_token=per_act_token),
+        BatchedTritonExperts(max_num_tokens=max_num_tokens,
+                             dp_size=1,
+                             world_size=1,
+                             use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                             block_shape=block_shape))
+
+    return fused_experts(a,
+                         w1,
+                         w2,
+                         topk_weight,
+                         topk_ids,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale)
+
+
+# Note: same as torch_moe but with fused_topk factored out.
+def torch_moe2(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    if use_fp8_w8a8:
+        a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
+    else:
+        a_scale = None
+
+    out = torch.zeros(M * topk,
+                      w2.shape[1],
+                      dtype=torch.bfloat16,
+                      device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            if not use_fp8_w8a8:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+            else:
+                tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
+                                                w1_scale[i], block_shape,
+                                                torch.bfloat16)
+
+                tmp2 = SiluAndMul()(tmp1)
+                tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
+
+                out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
+                                                     w2_scale[i], block_shape,
+                                                     torch.bfloat16)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+@pytest.mark.parametrize("m", [32, 45, 64])  #[1, 33, 64, 222])
+@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    current_platform.seed_everything(7)
+    block_shape = [128, 128]
+
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+
+    use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    qtype = dtype if dtype == torch.torch.float8_e4m3fn else None
+
+    if use_fp8_w8a8:
+        block_n, block_k = block_shape[0], block_shape[1]
+        n_tiles_w1 = (2 * n + block_n - 1) // block_n
+        n_tiles_w2 = (k + block_n - 1) // block_n
+        k_tiles_w1 = (k + block_k - 1) // block_k
+        k_tiles_w2 = (n + block_k - 1) // block_k
+
+        finfo = torch.finfo(dtype)
+        fp8_min = finfo.min
+        fp8_max = finfo.max
+
+        w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
+        w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
+
+        factor_for_scale = 1e-2
+        w1_s = torch.rand(
+            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
+            device="cuda") * factor_for_scale
+        w2_s = torch.rand(
+            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
+            device="cuda") * factor_for_scale
+    else:
+        w1_s = None
+        w2_s = None
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, qtype, block_shape)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, use_fp8_w8a8, block_shape)
+
+        torch.testing.assert_close(baseline_output,
+                                   batched_output,
+                                   atol=2e-2,
+                                   rtol=0)
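Reviewer note on the `native_w8a8_block_matmul` reference above: it dequantizes one (block_n, block_k) tile at a time and accumulates `(a @ b.t()) * s` in float32. Below is a minimal standalone sketch of the same tiling logic using int8 absmax quantization and toy shapes; everything here is illustrative and not vLLM code (the docstring says the reference is agnostic to int8 vs fp8, so int8 keeps the sketch runnable anywhere):

```python
# Toy check of block-wise w8a8 matmul logic (int8 stand-in), pure torch.
import torch

torch.manual_seed(0)
M, N, K, block = 4, 8, 16, 8          # two k-blocks, one n-block

A = torch.randn(M, K)
B = torch.randn(N, K)

# Per-token-group quantization of A: one scale per (token, k-block).
A_grouped = A.view(M, K // block, block)
As = A_grouped.abs().amax(dim=-1) / 127.0                    # [M, K//block]
A_q = (A_grouped / As[..., None]).round().clamp(-127, 127).to(torch.int8)

# Per-block quantization of B: one scale per (n-block, k-block).
B_blocks = B.view(N // block, block, K // block, block)
Bs = B_blocks.abs().amax(dim=(1, 3)) / 127.0                 # [N//b, K//b]
B_q = (B_blocks / Bs[:, None, :, None]).round().clamp(-127, 127).to(torch.int8)

# Dequantize tile-by-tile and accumulate, mirroring the reference loop.
C = torch.zeros(M, N)
for i in range(K // block):          # k tiles
    for j in range(N // block):      # n tiles
        a = A_q[:, i, :].float()                 # [M, block_k]
        b = B_q[j, :, i, :].float()              # [block_n, block_k]
        s = As[:, i:i + 1] * Bs[j, i]            # combined per-tile scales
        C[:, j * block:(j + 1) * block] += (a @ b.t()) * s

print((C - A @ B.t()).abs().max())   # small; bounded by int8 rounding
```

The residual error is dominated by the quantization step, which is why the tests above compare the fused kernel against this reference with loose tolerances rather than exact equality.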
--- next file ---

@@ -33,7 +33,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                             get_default_config)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import round_up

 PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
                        (222, 2048, 1024)]

@@ -74,6 +77,11 @@ class ProcessGroupInfo:
     device: torch.device


+@pytest.fixture(scope="function", autouse=True)
+def use_pplx_backend(monkeypatch):
+    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")
+
+
 def _worker_parallel_launch(
     local_rank: int,
     world_size: int,

@@ -275,6 +283,70 @@ def batched_moe(
     return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


+def native_w8a8_block_matmul(A: torch.Tensor,
+                             B: torch.Tensor,
+                             As: torch.Tensor,
+                             Bs: torch.Tensor,
+                             block_size,
+                             output_dtype=torch.bfloat16):
+    """This function performs matrix multiplication with block-wise
+    quantization using native torch.
+    It is agnostic to the input data type and can be used for both int8 and
+    fp8 data types.
+
+    It takes two input tensors `A` and `B` (int8) with scales `As` and
+    `Bs` (float32).
+    The output is returned in the specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32).contiguous()
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
+        f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
+    assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
+
+
 # Note: same as torch_moe but with fused_topk factored out.
 def torch_moe2(
     a: torch.Tensor,

@@ -282,17 +354,44 @@ def torch_moe2(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
     M, K = a.shape
     topk = topk_ids.shape[1]
+
     a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
-    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    if use_fp8_w8a8:
+        a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
+    else:
+        a_scale = None
+
+    out = torch.zeros(M * topk,
+                      w2.shape[1],
+                      dtype=torch.bfloat16,
+                      device=a.device)
     num_experts = w1.shape[0]
     for i in range(num_experts):
         mask = (topk_ids == i).view(-1)
         if mask.sum():
-            out[mask] = SiluAndMul()(
-                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+            if not use_fp8_w8a8:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+            else:
+                tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
+                                                w1_scale[i], block_shape,
+                                                torch.bfloat16)
+
+                tmp2 = SiluAndMul()(tmp1)
+                tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
+
+                out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
+                                                     w2_scale[i], block_shape,
+                                                     torch.bfloat16)

     return (out.view(M, -1, w2.shape[1]) *
             topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)

@@ -497,6 +596,10 @@ def pplx_moe(
     w2: torch.Tensor,
     topk_weight: torch.Tensor,
     topk_ids: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    qtype: Optional[torch.dtype] = None,
+    block_shape: Optional[list[int]] = None,
     use_compile: bool = True,
     use_cudagraphs: bool = True,
 ) -> torch.Tensor:

@@ -506,9 +609,17 @@ def pplx_moe(
     device = torch.device("cuda", rank)
     hidden_dim = a.shape[1]
     num_experts = w1.shape[0]
-    block_size = 128
+    block_size = block_shape[1] if block_shape is not None else 128
     topk = topk_ids.shape[1]
-    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
+
+    if qtype is not None:
+        a_dtype = qtype
+        # This is probably not right
+        scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16)
+    else:
+        a_dtype = a.dtype
+        scale_bytes = 0

     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,

@@ -518,10 +629,8 @@ def pplx_moe(
         world_size=world_size,
         dp_size=dp_size,
         hidden_dim=hidden_dim,
-        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
-        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
-                                ((hidden_dim + block_size - 1) // block_size *
-                                 torch.float32.itemsize)),
+        hidden_dim_bytes=hidden_dim * a_dtype.itemsize,
+        hidden_dim_scale_bytes=scale_bytes,
     )

     topk_ids = topk_ids.to(dtype=torch.uint32)

@@ -532,11 +641,15 @@ def pplx_moe(
         world_size,
         rank,
         dp_size,
+        quant_dtype=qtype,
+        block_shape=block_shape,
     )

-    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
+    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                    world_size=world_size,
-                                   dp_size=dp_size)
+                                   dp_size=dp_size,
+                                   use_fp8_w8a8=qtype==torch.float8_e4m3fn,
+                                   block_shape=block_shape)

     fused_experts = FusedMoEModularKernel(
         prepare_finalize,

@@ -552,6 +665,13 @@ def pplx_moe(
     w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
     w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

+    if w1_scale is not None:
+        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
+        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
+    else:
+        w1_scale_chunk = None
+        w2_scale_chunk = None
+
     if use_compile:
         _fused_experts = torch.compile(fused_experts,
                                        backend='inductor',

@@ -564,6 +684,8 @@ def pplx_moe(
                                w2_chunk,
                                chunk_topk_weight,
                                chunk_topk_ids,
+                               w1_scale=w1_scale_chunk,
+                               w2_scale=w2_scale_chunk,
                                global_num_experts=num_experts)

     if use_cudagraphs:

@@ -576,6 +698,8 @@ def pplx_moe(
                                w2_chunk,
                                chunk_topk_weight,
                                chunk_topk_ids,
+                               w1_scale=w1_scale_chunk,
+                               w2_scale=w2_scale_chunk,
                                global_num_experts=num_experts)

     torch.cuda.synchronize()

@@ -638,6 +762,10 @@ def _pplx_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
+    w1_s: Optional[torch.Tensor] = None,
+    w2_s: Optional[torch.Tensor] = None,
+    qtype: Optional[torch.dtype] = None,
+    block_shape: Optional[list[int]] = None,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()

@@ -649,11 +777,20 @@ def _pplx_moe(

     moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

+    use_fp8_w8a8 = qtype == torch.float8_e4m3fn
+
+    device = torch.device("cuda", pgi.rank)
+    a = a.to(device)
+    w1 = w1.to(device)
+    w2 = w2.to(device)
+    w1_s = w1_s.to(device) if w1_s is not None else None
+    w2_s = w2_s.to(device) if w2_s is not None else None
+
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
         pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
-                               topk_weight, topk_ids)
+                               topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape)
         # TODO (bnell): fix + re-enable
         #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
         #                              topk_ids)

@@ -670,7 +807,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_moe(

@@ -683,9 +820,40 @@ def test_pplx_moe(
     current_platform.seed_everything(7)
     m, n, k = mnk
     world_size, dp_size = world_dp_size
-    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

-    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
+
+    if use_fp8_w8a8:
+        block_shape = [128, 128]
+        quant_type = torch.float8_e4m3fn
+        block_n, block_k = block_shape[0], block_shape[1]
+        n_tiles_w1 = (2 * n + block_n - 1) // block_n
+        n_tiles_w2 = (k + block_n - 1) // block_n
+        k_tiles_w1 = (k + block_k - 1) // block_k
+        k_tiles_w2 = (n + block_k - 1) // block_k
+
+        finfo = torch.finfo(dtype)
+        fp8_min = finfo.min
+        fp8_max = finfo.max
+
+        w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
+        w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
+
+        factor_for_scale = 1e-2
+        w1_s = torch.rand(
+            (e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
+            device="cuda") * factor_for_scale
+        w2_s = torch.rand(
+            (e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
+            device="cuda") * factor_for_scale
+    else:
+        block_shape = None
+        quant_type = None
+        w1_s = None
+        w2_s = None
+
+    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_type, block_shape)
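A note on the `scale_bytes` sizing added to `pplx_moe` above (which the diff itself flags with "This is probably not right"): it reserves one float32 scale per k-block of the hidden dimension, rounded up to a 16-byte boundary. A standalone sketch of that arithmetic with illustrative numbers:

```python
# Arithmetic behind hidden_dim_scale_bytes in the pplx_moe change above.
# Standalone sketch; the helper names mirror vllm.utils but are redefined
# here so the snippet runs on its own.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)          # ceiling division

def round_up(x: int, m: int) -> int:
    return cdiv(x, m) * m

hidden_dim, block_size = 1024, 128   # illustrative values
float32_itemsize = 4

# One float32 scale per k-block of the hidden dim, padded to 16 bytes.
scale_bytes = round_up(cdiv(hidden_dim, block_size) * float32_itemsize, 16)
print(scale_bytes)  # 32: 8 blocks * 4 bytes, already 16-byte aligned
```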
--- next file ---

@@ -83,6 +83,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
         assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)

+        # Intranode doesn't work yet.
+        self.internode = True
+
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
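The `use_pplx_backend` fixture added to the test file earlier pins the all2all backend through an environment variable before vLLM reads it, which is what routes traffic to this manager. A minimal self-contained pytest sketch of the same pattern (the env var name matches the diff; the test body is illustrative):

```python
# Sketch of forcing the pplx all2all backend in tests via an env var.
import os
import pytest

@pytest.fixture(autouse=True)
def use_pplx_backend(monkeypatch):
    # monkeypatch restores the env var automatically after each test.
    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")

def test_backend_env_is_set():
    assert os.environ["VLLM_ALL2ALL_BACKEND"] == "pplx"
```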
--- next file ---

@@ -4,7 +4,8 @@ from contextlib import contextmanager
 from typing import Any, Optional

 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON

 _config: Optional[dict[str, Any]] = None

@@ -29,6 +30,7 @@ __all__ = [
     "FusedMoeWeightScaleSupported",
     "override_config",
     "get_config",
+    "MOE_DP_CHUNK_SIZE",
 ]

 if HAS_TRITON:
--- next file ---

@@ -9,7 +9,9 @@ import triton.language as tl
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache, moe_kernel_quantize_input)
+from vllm.utils import round_up


 @triton.jit

@@ -315,8 +317,8 @@ def invoke_moe_batched_triton_kernel(
         expert_num_tokens: torch.Tensor,  # [E]
         compute_type: tl.dtype,
         # Quantization data
-        A_scale: torch.Tensor,
-        B_scale: torch.Tensor,
+        A_scale: Optional[torch.Tensor],
+        B_scale: Optional[torch.Tensor],
         B_zp: torch.Tensor,
         # Quantization schemes
         use_fp8_w8a8: bool,

@@ -335,7 +337,7 @@ def invoke_moe_batched_triton_kernel(
     BLOCK_K = config['BLOCK_SIZE_K']
     assert (torch.compiler.is_compiling()
             or torch.cuda.is_current_stream_capturing()
-            or max_num_tokens % BLOCK_M == 0)
+            or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}"

     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))

@@ -388,13 +390,22 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     that the PPLX dispatch/combine kernels use.
     """

-    def __init__(self, max_num_tokens: Optional[int], world_size: int,
-                 dp_size: int, rank: int):
+    def __init__(self,
+                 max_num_tokens: Optional[int],
+                 world_size: int,
+                 dp_size: int,
+                 rank: int,
+                 qtype: Optional[torch.dtype] = None,
+                 per_act_token: bool = False,
+                 block_shape: Optional[list[int]] = None):
         super().__init__()
         self.world_size = world_size
         self.dp_size = dp_size
         self.rank = rank
         self.max_num_tokens = max_num_tokens
+        self.per_act_token = per_act_token
+        self.block_shape = block_shape
+        self.qtype = qtype

     def prepare(
         self,

@@ -436,20 +447,47 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
-            dtype=a1.dtype,
+            dtype=self.qtype if self.qtype is not None else a1.dtype,
             device=a1.device)

+        if self.qtype is not None:
+            _, block_k = self.block_shape
+            k_tiles = (hidden_dim + block_k - 1) // block_k
+            b_a1_scale = torch.zeros(
+                (num_local_experts, self.max_num_tokens, k_tiles),
+                dtype=torch.float32,
+                device=a1.device)
+        else:
+            assert a1_scale is None
+            b_a1_scale = None
+
         first_expert = num_local_experts * self.rank
         last_expert = first_expert + num_local_experts

         for expert_id in range(first_expert, last_expert):
             topks = torch.any(topk_ids == expert_id, dim=1).flatten()
             rows = torch.count_nonzero(topks.flatten())
-            b_a1[expert_id -
-                 first_expert, :rows, :] = a1[:topks.numel()][topks]
-            tokens_per_expert[expert_id - first_expert] = rows
+            rhs = a1[:topks.numel()][topks]
+            idx = expert_id - first_expert
+            if self.qtype is not None:
+                if a1_scale is not None:
+                    rhs_a1_scale = a1_scale[:topks.numel()][topks]
+                else:
+                    rhs_a1_scale = None
+                b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
+                    moe_kernel_quantize_input(
+                        rhs,
+                        rhs_a1_scale,
+                        self.qtype,
+                        self.per_act_token,
+                        self.block_shape,
+                    ))
+            else:
+                b_a1[idx, :rows, :] = rhs
+
+            tokens_per_expert[idx] = rows

-        return b_a1, a1_scale, tokens_per_expert
+        return b_a1, b_a1_scale, tokens_per_expert

     def finalize(
         self,

@@ -499,15 +537,15 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         block_m: Optional[int] = None,
     ):
         super().__init__()
-        assert block_shape is None
         assert block_m is None
-        assert not use_fp8_w8a8, "NYI"
         assert not use_int8_w8a8, "NYI"
         assert not use_int8_w8a16, "NYI"
         assert not use_int4_w4a16, "NYI"
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size
         self.dp_size = dp_size
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.block_shape = block_shape

     def workspace_shapes(
         self,

@@ -522,7 +560,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         num_dp = self.world_size // self.dp_size
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        #print(f"WORKSPACE {max_num_tokens} {num_dp}")
         workspace13 = num_experts * max_num_tokens * num_dp * K
         workspace2 = max_num_tokens * num_dp * N
         return (workspace13, workspace2, a.dtype)

@@ -579,6 +616,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         else:
             num = int(expert_num_tokens[expert].item())
             tmp = _resize_cache(workspace2, (num, N))
+            assert not self.use_fp8_w8a8
             input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
             self.activation(activation, tmp, input)
             out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)

@@ -586,6 +624,61 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return out


+def batched_moe_kernel_quantize_input(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    num_tokens: int,
+    E: int,
+    N: int,
+    expert_num_tokens: torch.Tensor,
+    qtype: Optional[torch.dtype],
+    per_channel_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (True or
+            torch.compiler.is_compiling()
+            or torch.cuda.is_current_stream_capturing()):
+        # Note: this does a bunch of extra work because expert_num_tokens is ignored
+        # but it does support torch.compile + cudagraphs.
+        hidden_dim = A.size(-1)
+        if block_shape is not None:
+            block_shape = [block_shape[1], block_shape[0]]
+        assert A_scale is None or A_scale.dim() == 2
+        A_q, A_q_scale = moe_kernel_quantize_input(
+            A.view(-1, hidden_dim),
+            A_scale,
+            qtype,
+            per_channel_quant,
+            block_shape)
+        A_q = A_q.view(E, -1, hidden_dim)
+        if A_q_scale is not None:
+            A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
+        return A_q, A_q_scale
+
+    if qtype is not None:
+        assert block_shape is not None
+        A_q = torch.empty_like(A, dtype=qtype)
+        block_n, block_k = block_shape
+        n_tiles = ((N // 2) + block_n - 1) // block_n
+        scale_shape = (E, num_tokens, n_tiles)
+        A_q_scale = torch.empty(scale_shape,
+                                dtype=torch.float32,
+                                device=A.device)
+        for e in range(E):
+            num_tokens = expert_num_tokens[e]
+            if num_tokens > 0:
+                A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
+                    A[e, :num_tokens],
+                    A_scale[e, :num_tokens] if A_scale else None, qtype,
+                    per_channel_quant, [block_k, block_n])
+                A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
+
+        return A_q, A_q_scale
+    else:
+        return A, A_scale
+
+
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     """
     A Triton based MoE expert class that operates on expert batched format,

@@ -595,12 +688,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        max_num_tokens: Optional[int] = None,
+        max_num_tokens: int,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         block_shape: Optional[list[int]] = None,
+        per_act_token: bool = False,
         world_size: int = 1,
         dp_size: int = 1,
     ):

@@ -610,11 +704,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.use_int4_w4a16 = use_int4_w4a16
         self.use_int8_w8a16 = use_int8_w8a16
         self.block_shape = block_shape
-        self.max_num_tokens = max_num_tokens
         assert not use_int8_w8a8, "NYI"
         assert not use_int4_w4a16, "NYI"
         self.world_size = world_size
         self.dp_size = dp_size
+        self.per_act_token = per_act_token
+        self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
+        self.max_num_tokens = max_num_tokens

     def workspace_shapes(
         self,

@@ -627,10 +723,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ) -> tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.size(
-            0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
-        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
+        workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)

     def apply(

@@ -702,7 +796,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")

-        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
         # We can reuse the memory between these because by the time we need
         # cache3, we're done with cache1
         intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))

@@ -730,15 +823,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.activation(activation, intermediate_cache2.view(-1, N // 2),
                         intermediate_cache1.view(-1, N))

-        #qintermediate_cache2 = intermediate_cache2
-        a2q_scale = a2_scale
-        # TODO (varun) : support w8a8
-        assert not self.use_fp8_w8a8
-        #if self.use_fp8_w8a8:
-        #    qintermediate_cache2, a2q_scale = _fp8_quantize(
-        #        intermediate_cache2, a2_scale, self.block_shape)
+        qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
+            intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
+            self.qtype, self.per_act_token, self.block_shape)

-        invoke_moe_batched_triton_kernel(A=intermediate_cache2,
+        invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
                                          B=w2,
                                          C=intermediate_cache3,
                                          expert_num_tokens=expert_num_tokens,
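A note on `batched_moe_kernel_quantize_input` above: in its per-expert path, each expert's batch is padded to `max_num_tokens`, so only the first `expert_num_tokens[e]` rows carry data and only those get quantized; the padding rows stay zero and are ignored downstream. A toy standalone sketch of that loop using int8 absmax per row (names and shapes here are illustrative, not vLLM's API):

```python
# Toy per-expert quantize loop over a ragged, padded expert batch.
import torch

torch.manual_seed(0)
E, max_tokens, hidden = 3, 4, 8
A = torch.randn(E, max_tokens, hidden)
expert_num_tokens = torch.tensor([2, 0, 3])   # valid rows per expert

A_q = torch.zeros_like(A, dtype=torch.int8)
A_scale = torch.zeros(E, max_tokens, 1)

for e in range(E):
    n = int(expert_num_tokens[e])
    if n > 0:
        rows = A[e, :n]
        # One absmax scale per valid token row.
        scale = rows.abs().amax(dim=-1, keepdim=True) / 127.0
        A_q[e, :n] = (rows / scale).round().clamp(-127, 127).to(torch.int8)
        A_scale[e, :n] = scale

# Rows beyond expert_num_tokens[e] remain zero and are never read.
print(A_q[0, :2].float().mul(A_scale[0, :2]).sub(A[0, :2]).abs().max())
```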
--- next file ---

@@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        use_fp8_w8a8: bool,
-        use_int8_w8a8: bool,
-        use_int8_w8a16: bool,
-        use_int4_w4a16: bool,
-        per_channel_quant: bool,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_channel_quant: bool = False,
         block_shape: Optional[list[int]] = None,
         block_m: Optional[int] = None,
     ):
|||||||
@ -8,6 +8,9 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from compressed_tensors.quantization import (QuantizationArgs,
|
||||||
|
QuantizationStrategy,
|
||||||
|
QuantizationType)
|
||||||
from torch.nn.parameter import UninitializedParameter
|
from torch.nn.parameter import UninitializedParameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -26,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op, cdiv
|
||||||
|
|
||||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||||
|
|
||||||
@ -56,7 +59,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
# Note: this limit is somewhat arbitrary and might be changed later.
|
# Note: this limit is somewhat arbitrary and might be changed later.
|
||||||
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
||||||
MOE_DP_CHUNK_SIZE = 256
|
MOE_DP_CHUNK_SIZE = 128
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -72,7 +75,7 @@ class FusedMoEParallelConfig:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_pplx_kernels(self):
|
def use_pplx_kernels(self):
|
||||||
return self.dp_size > 1 and self.use_ep and \
|
return self.dp_size > 1 and self.use_ep and has_pplx and \
|
||||||
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -191,7 +194,8 @@ class MoEConfig:
|
|||||||
num_local_experts: int
|
num_local_experts: int
|
||||||
moe_parallel_config: FusedMoEParallelConfig
|
moe_parallel_config: FusedMoEParallelConfig
|
||||||
|
|
||||||
in_dtype: torch.dtype # The activation type.
|
in_dtype: torch.dtype # The post quantization activation type.
|
||||||
|
quant_dtype: Optional[torch.dtype] = None
|
||||||
|
|
||||||
# TODO: add more quantization params, blocked, per-token, etc.
|
# TODO: add more quantization params, blocked, per-token, etc.
|
||||||
block_size: int = 128
|
block_size: int = 128
|
||||||
@ -238,6 +242,18 @@ class FusedMoeWeightScaleSupported(Enum):
|
|||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_config_input_activations(
|
||||||
|
quant_config: Optional[QuantizationConfig]
|
||||||
|
) -> Optional[QuantizationArgs]:
|
||||||
|
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||||
|
and "Linear" in quant_config.target_scheme_map and
|
||||||
|
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||||
|
return quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"input_activations")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -253,6 +269,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
|
|
||||||
prepare_finalize = None
|
prepare_finalize = None
|
||||||
if moe.use_pplx_kernels:
|
if moe.use_pplx_kernels:
|
||||||
|
# For blocked per token: set to
|
||||||
|
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||||
|
# For per-token: set to sizeof(float32)
|
||||||
|
if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
|
||||||
|
hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
|
||||||
|
hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
|
||||||
|
torch.float32.itemsize)
|
||||||
|
else:
|
||||||
|
hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
|
||||||
|
hidden_scale_bytes = 0
|
||||||
|
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
max_num_tokens=moe.max_num_tokens,
|
max_num_tokens=moe.max_num_tokens,
|
||||||
num_experts=moe.num_experts,
|
num_experts=moe.num_experts,
|
||||||
@@ -262,18 +289,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-                # For blocked per token: set to
-                #   ceil_div(hidden_dim, block_size) * sizeof(float32)
-                # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
-                    (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
-                    torch.float32.itemsize)),
-                group_name=all2all_manager.cpu_group.group_name,
+                hidden_dim_bytes=hidden_dim_bytes,
+                hidden_dim_scale_bytes=hidden_scale_bytes,
             )
 
+            if not all2all_manager.internode:
+                all_to_all_args["group_name"] = \
+                    all2all_manager.cpu_group.group_name
+
             handle = all2all_manager.get_handle(all_to_all_args)
 
+            logger.debug("PplxPrepareAndFinalize")
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@@ -281,7 +307,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 rank=all2all_manager.rank,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.in_dtype,
+                quant_dtype=moe.quant_dtype,
             )
 
         if prepare_finalize is not None:
@@ -346,33 +372,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None
 
-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
         if isinstance(prepare_finalize,
                       (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
+            return BatchedTritonExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=all2all_manager.world_size,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
-        return experts
+            return TritonExperts()
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -785,14 +796,32 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
+        logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
+        quant_dtype: Optional[torch.dtype] = None
+        if quant_config is not None:
+            input_activations = get_quant_config_input_activations(
+                quant_config)
+            if (input_activations is not None
+                    and input_activations.num_bits == 8):
+                if input_activations.type == QuantizationType.FLOAT:
+                    quant_dtype = torch.float8_e4m3fn
+                elif input_activations.type == QuantizationType.INT:
+                    quant_dtype = torch.int8
+
+            # Total hack
+            if quant_config.__class__.__name__ == "Fp8Config":
+                quant_dtype = torch.float8_e4m3fn
+
+        logger.info("QUANT_DTYPE %s", quant_dtype)
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            # TODO (bnell): this needs to be fixed for quantized types.
-            in_dtype=params_dtype,
+            in_dtype=vllm_config.model_config.dtype,
+            quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
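The constructor now derives the dispatch dtype from the quantization config: 8-bit float activations map to `torch.float8_e4m3fn`, 8-bit int to `torch.int8`, and, per the author's own `# Total hack` comment, `Fp8Config` is special-cased by class name (presumably because it does not expose a `target_scheme_map`). A standalone sketch of the mapping, with a stand-in enum mimicking `QuantizationType`:

from enum import Enum
from typing import Optional

import torch


class QuantizationTypeSketch(str, Enum):
    FLOAT = "float"
    INT = "int"


def derive_quant_dtype(num_bits: int,
                       qtype: QuantizationTypeSketch) -> Optional[torch.dtype]:
    # Same decision table as the hunk above.
    if num_bits != 8:
        return None
    if qtype == QuantizationTypeSketch.FLOAT:
        return torch.float8_e4m3fn
    if qtype == QuantizationTypeSketch.INT:
        return torch.int8
    return None


print(derive_quant_dtype(8, QuantizationTypeSketch.FLOAT))  # torch.float8_e4m3fn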
@@ -832,15 +861,14 @@ class FusedMoE(torch.nn.Module):
         self.batched_hidden_states: Optional[torch.Tensor] = None
         self.batched_router_logits: Optional[torch.Tensor] = None
         if self.moe_parallel_config.use_pplx_kernels:
-            act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.hidden_size),
-                dtype=act_dtype,
+                dtype=vllm_config.model_config.dtype,
                 device=torch.cuda.current_device())
 
             self.batched_router_logits = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.global_num_experts),
-                dtype=act_dtype,
+                dtype=vllm_config.model_config.dtype,
                 device=torch.cuda.current_device())
 
     @property
@@ -1251,7 +1279,7 @@ class FusedMoE(torch.nn.Module):
 
             assert (self.batched_hidden_states.size(0)  # type: ignore
                     >= chunk_size)
             assert (self.batched_router_logits.size(0)  # type: ignore
                     >= chunk_size)
             staged_hidden_states = self.batched_hidden_states[:
                                                               chunk_size, :]  # type: ignore
@@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                                    per_act_token,
                                                    self.block_shape)
 
+        if a1q_scale is not None and a1q_scale.dim() == 1:
+            assert a1q_scale.numel() == 1
+            a1q_scale = a1q_scale.view(1, 1)
+
         # rem_experts need to be 0 for pplx to work properly.
         rem_experts = num_experts % self.world_size
         assert rem_experts == 0
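The new guard normalizes a per-tensor quantization scale to 2-D before the pplx dispatch, presumably because downstream indexing assumes scales have explicit row/column dimensions. A minimal repro of the reshape:

import torch

a1q_scale = torch.tensor([0.0123])    # per-tensor scale, shape (1,)
if a1q_scale is not None and a1q_scale.dim() == 1:
    assert a1q_scale.numel() == 1
    a1q_scale = a1q_scale.view(1, 1)  # shape (1, 1)
print(a1q_scale.shape)                # torch.Size([1, 1])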
@@ -90,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             float32_size = torch.float32.itemsize
             block_size = (self.block_shape[0] if self.block_shape is not None
                           else 1) * float32_size
-            expert_x_scale = torch.empty(
+            expert_x_scale = torch.zeros(
                 (
                     num_experts,
                     expert_x.size(1),
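Switching `expert_x_scale` from `torch.empty` to `torch.zeros` is a correctness fix rather than style: scale slots for padding tokens may never be written, and dequantizing against uninitialized memory can inject NaN/Inf. A toy demonstration of the failure mode this guards against (an assumed hazard, not taken from the diff):

import torch

# A zero payload times a garbage scale is only "harmlessly zero" if the
# garbage is finite; uninitialized memory can hold inf or nan.
print(torch.tensor(0.0) * torch.tensor(float("inf")))  # -> nan
print(torch.tensor(0.0) * torch.zeros(1))              # -> tensor([0.]), safe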
@@ -11,9 +11,10 @@ from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
+                                                  FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
@@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             self.fused_experts = functools.partial(  # type: ignore
                 fused_experts,
+                use_fp8_w8a8=True,
                 block_shape=self.quant_config.weight_block_size,
                 allow_deep_gemm=self.allow_deep_gemm)
 
+        self.use_pplx_kernels = False
+        self.rocm_aiter_moe_enabled = False
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
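Folding `use_fp8_w8a8=True` into the `functools.partial` pre-binds the flag for every later `self.fused_experts(...)` call, which is why the final hunk below can drop the explicit flag at the call site. Note that Python lets a call-site keyword override a partial-bound one, so this is a default, not a lock:

import functools


def fused_experts_stub(x, *, use_fp8_w8a8=False, block_shape=None):
    # Stand-in for vLLM's fused_experts; only the binding behavior matters here.
    return (x, use_fp8_w8a8, block_shape)


fe = functools.partial(fused_experts_stub, use_fp8_w8a8=True,
                       block_shape=[128, 128])
print(fe("a"))                      # ('a', True, [128, 128])
print(fe("a", use_fp8_w8a8=False))  # call-site keyword still wins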
@@ -764,19 +769,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w2_input_scale
 
     def select_gemm_impl(self, prepare_finalize):
+        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+            BatchedPrepareAndFinalize, BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+            PplxPrepareAndFinalize)
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)
 
         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet.")
 
-        experts = TritonOrDeepGemmExperts(
-            use_fp8_w8a8=True,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+        all2all_manager = get_ep_group().device_communicator.all2all_manager
+        assert all2all_manager is not None
 
-        return experts
+        if isinstance(prepare_finalize,
+                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
+            logger.debug("BatchedTritonExperts(fp8)")
+            self.use_pplx_kernels = True
+            return BatchedTritonExperts(
+                max_num_tokens=MOE_DP_CHUNK_SIZE,
+                world_size=all2all_manager.world_size,
+                dp_size=all2all_manager.tp_group.world_size,
+                use_fp8_w8a8=True,
+                block_shape=self.quant_config.weight_block_size,
+                per_act_token=False,  #?
+            )
+        else:
+            logger.debug("TritonOrDeepGemmExperts(fp8)")
+            return TritonOrDeepGemmExperts(
+                use_fp8_w8a8=True,
+                block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
+            )
 
     def apply(
         self,
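With this, `Fp8MoEMethod.select_gemm_impl` mirrors the unquantized path earlier in the diff: the type of the prepare/finalize object chooses the experts implementation, and the fp8 flags ride along (the `#?` on `per_act_token=False` is the author's own open question). A stripped-down sketch of the dispatch pattern, with placeholder classes:

# Placeholder classes; only the isinstance-based dispatch is taken from the diff.
class BatchedPF:
    pass


class PplxPF:
    pass


def select_gemm_impl(prepare_finalize, block_shape):
    if isinstance(prepare_finalize, (BatchedPF, PplxPF)):
        return ("BatchedTritonExperts",
                dict(use_fp8_w8a8=True, block_shape=block_shape,
                     per_act_token=False))
    return ("TritonOrDeepGemmExperts",
            dict(use_fp8_w8a8=True, block_shape=block_shape))


print(select_gemm_impl(PplxPF(), [128, 128])[0])  # BatchedTritonExperts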
@@ -807,7 +831,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-        )
+            indices_type=torch.uint32 if self.use_pplx_kernels else None)
 
         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
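`indices_type=torch.uint32` requests expert indices in the layout the pplx kernels presumably consume, instead of the int64 that PyTorch's `topk` produces by default, halving the bytes per index. Note `torch.uint32` only exists in recent PyTorch (2.3+), with limited operator support beyond casting:

import torch

topk_ids = torch.topk(torch.randn(4, 8), k=2).indices
print(topk_ids.dtype, topk_ids.element_size())  # torch.int64 8
u32_ids = topk_ids.to(torch.uint32)             # requires PyTorch >= 2.3
print(u32_ids.dtype, u32_ids.element_size())    # torch.uint32 4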
@@ -854,7 +878,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             inplace=True,
             activation=activation,
-            use_fp8_w8a8=True,
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_map=expert_map,