v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@ -43,6 +43,7 @@ import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
@ -90,7 +91,7 @@ Constraints for this example:
* Number of heads in Q must be divisible by number of heads in K
* mma_tiler_mn must be 128,128
* Batch size must be the same for Q, K, and V tensors
* For causal masking, use --has_casual_mask (note: specify without =True/False)
* For causal masking, use --is_causal (note: specify without =True/False)
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
"""
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
smem_copy_atom = sm100_utils.get_smem_store_op(
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
)
tiled_smem_store = cute.make_tiled_copy(
smem_copy_atom,
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
tiler_mn=tiled_tmem_load.tiler_mn,
)
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
return tile_sched_params, grid
def run_fmha_and_verify(
def run(
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
in_dtype: Type[cutlass.Numeric],
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
pv_acc_dtype: Type[cutlass.Numeric],
mma_tiler_mn: Tuple[int, int],
is_persistent: bool,
has_casual_mask: bool,
is_causal: bool,
scale_q: float,
scale_k: float,
scale_v: float,
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
warmup_iterations: int,
iterations: int,
skip_ref_check: bool,
use_cold_l2: bool = False,
**kwargs,
):
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
:type mma_tiler_mn: Tuple[int, int]
:param is_persistent: Whether to use persistent kernel optimization
:type is_persistent: bool
:param has_casual_mask: Whether to apply causal masking
:type has_casual_mask: bool
:param is_causal: Whether to apply causal masking
:type is_causal: bool
:param scale_q: Scaling factor for query tensor
:type scale_q: float
:param scale_k: Scaling factor for key tensor
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
:type iterations: int
:param skip_ref_check: Skip validation against reference implementation
:type skip_ref_check: bool
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
:type use_cold_l2: bool
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
:raises RuntimeError: If GPU is unavailable for computation
:return: Execution time of the FMHA kernel in microseconds
:rtype: float
"""
print(f"Running Blackwell SM100 FMHA test with:")
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
print(f" pv_acc_dtype: {pv_acc_dtype}")
print(f" mma_tiler_mn: {mma_tiler_mn}")
print(f" is_persistent: {is_persistent}")
print(f" has_casual_mask: {has_casual_mask}")
print(f" is_causal: {is_causal}")
print(f" scale_q: {scale_q}")
print(f" scale_k: {scale_k}")
print(f" scale_v: {scale_v}")
print(f" inv_scale_o: {inv_scale_o}")
print(f" scale_softmax: {scale_softmax}")
print(f" tolerance: {tolerance}")
print(f" warmup_iterations: {warmup_iterations}")
print(f" iterations: {iterations}")
print(f" skip_ref_check: {skip_ref_check}")
print(f" use_cold_l2: {use_cold_l2}")
# Unpack parameters
b, s_q, h_q, d = q_shape
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
mma_tiler = (*mma_tiler_mn, d)
mask_type = MaskType.NO_MASK
if has_casual_mask:
if is_causal:
mask_type = MaskType.CAUSAL_MASK
else:
if isinstance(s_k, tuple):
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
# Warmup
for _ in range(warmup_iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
# Execute kernel
for _ in range(iterations):
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
torch.cuda.synchronize()
def run_torch_fmha(
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
):
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
h_q = q.shape[2]
h_k = k.shape[2]
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
v = v.transpose(1, 2)
# For the situation that torch has not supported, we need to handle it manually
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
situation1 = is_causal and (q.is_nested or k.is_nested)
situation2 = (q.is_nested and not k.is_nested) or (
not q.is_nested and k.is_nested
)
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref_i = ref_i.transpose(0, 1) * scale_output
ref_list.append(ref_i)
if q.is_nested:
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
attn_mask=None,
dropout_p=0.0,
scale=scale_softmax,
is_causal=has_casual_mask,
is_causal=is_causal,
)
ref = ref.transpose(1, 2) * scale_output
ref = ref.transpose(1, 2) * scale_output
return ref
if not skip_ref_check:
# Execute kernel once for reference checking
compiled_fmha(
q_tensor.iterator,
k_tensor.iterator,
v_tensor.iterator,
o_tensor.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
print("Verifying results...")
o_ref = run_torch_fmha(
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
)
if o_ref.is_nested:
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
print("Results verified successfully!")
def generate_tensors():
_, q_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
in_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
_, k_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, v_tensor_workspace, _ = create_and_pad_tensor(
kv_shape,
kv_padding,
in_dtype,
s_cumsum=cum_seqlen_k_torch,
is_dynamic_layout=True,
)
_, o_tensor_workspace, _ = create_and_pad_tensor(
qo_shape,
qo_padding,
out_dtype,
s_cumsum=cum_seqlen_q_torch,
is_dynamic_layout=True,
)
return testing.JitArguments(
q_tensor_workspace.iterator,
k_tensor_workspace.iterator,
v_tensor_workspace.iterator,
o_tensor_workspace.iterator,
problem_size,
cum_seqlen_q,
cum_seqlen_k,
scale_softmax_log2,
scale_output,
current_stream,
)
workspace_count = 1
if use_cold_l2:
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
one_workspace_bytes = (
q_torch_effective.numel() * q_torch_effective.element_size()
+ k_torch_effective.numel() * k_torch_effective.element_size()
+ v_torch_effective.numel() * v_torch_effective.element_size()
+ o_torch_effective.numel() * o_torch_effective.element_size()
)
workspace_count = testing.get_workspace_count(
one_workspace_bytes, warmup_iterations, iterations
)
exec_time = testing.benchmark(
compiled_fmha,
workspace_generator=generate_tensors,
workspace_count=workspace_count,
stream=current_stream,
warmup_iterations=warmup_iterations,
iterations=iterations,
)
return exec_time # Return execution time in microseconds
if __name__ == "__main__":
def parse_comma_separated_ints(s: str):
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--has_casual_mask",
"--is_causal",
action="store_true",
help="Whether to use casual mask",
)
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
help="Skip reference check",
)
parser.add_argument(
"--use_cold_l2",
action="store_true",
default=False,
help="Use circular buffer tensor sets to ensure L2 cold cache",
)
args = parser.parse_args()
if len(args.q_shape) != 4:
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
torch.manual_seed(1111)
run_fmha_and_verify(
run(
args.q_shape,
args.k_shape,
args.in_dtype,
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
args.pv_acc_dtype,
args.mma_tiler_mn,
args.is_persistent,
args.has_casual_mask,
args.is_causal,
args.scale_q,
args.scale_k,
args.scale_v,
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.use_cold_l2,
)
print("PASS")