v4.1 release update v2. (#2481)
This commit is contained in:
@ -43,6 +43,7 @@ import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Int32, Int64, Float32, Boolean
|
||||
|
||||
@ -90,7 +91,7 @@ Constraints for this example:
|
||||
* Number of heads in Q must be divisible by number of heads in K
|
||||
* mma_tiler_mn must be 128,128
|
||||
* Batch size must be the same for Q, K, and V tensors
|
||||
* For causal masking, use --has_casual_mask (note: specify without =True/False)
|
||||
* For causal masking, use --is_causal (note: specify without =True/False)
|
||||
* For persistent scheduling, use --is_persistent (note: specify without =True/False)
|
||||
"""
|
||||
|
||||
@ -2373,11 +2374,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
smem_copy_atom = sm100_utils.get_smem_store_op(
|
||||
self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy(
|
||||
smem_copy_atom,
|
||||
layout_tv=tiled_tmem_load.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_tmem_load.tiler_mn,
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
|
||||
|
||||
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
|
||||
@ -2619,7 +2616,7 @@ class BlackwellFusedMultiHeadAttentionForward:
|
||||
return tile_sched_params, grid
|
||||
|
||||
|
||||
def run_fmha_and_verify(
|
||||
def run(
|
||||
q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int],
|
||||
in_dtype: Type[cutlass.Numeric],
|
||||
@ -2628,7 +2625,7 @@ def run_fmha_and_verify(
|
||||
pv_acc_dtype: Type[cutlass.Numeric],
|
||||
mma_tiler_mn: Tuple[int, int],
|
||||
is_persistent: bool,
|
||||
has_casual_mask: bool,
|
||||
is_causal: bool,
|
||||
scale_q: float,
|
||||
scale_k: float,
|
||||
scale_v: float,
|
||||
@ -2638,6 +2635,8 @@ def run_fmha_and_verify(
|
||||
warmup_iterations: int,
|
||||
iterations: int,
|
||||
skip_ref_check: bool,
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results.
|
||||
|
||||
@ -2670,8 +2669,8 @@ def run_fmha_and_verify(
|
||||
:type mma_tiler_mn: Tuple[int, int]
|
||||
:param is_persistent: Whether to use persistent kernel optimization
|
||||
:type is_persistent: bool
|
||||
:param has_casual_mask: Whether to apply causal masking
|
||||
:type has_casual_mask: bool
|
||||
:param is_causal: Whether to apply causal masking
|
||||
:type is_causal: bool
|
||||
:param scale_q: Scaling factor for query tensor
|
||||
:type scale_q: float
|
||||
:param scale_k: Scaling factor for key tensor
|
||||
@ -2690,9 +2689,13 @@ def run_fmha_and_verify(
|
||||
:type iterations: int
|
||||
:param skip_ref_check: Skip validation against reference implementation
|
||||
:type skip_ref_check: bool
|
||||
:param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache
|
||||
:type use_cold_l2: bool
|
||||
|
||||
:raises ValueError: If input shapes are incompatible or head dimension is unsupported
|
||||
:raises RuntimeError: If GPU is unavailable for computation
|
||||
:return: Execution time of the FMHA kernel in microseconds
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
print(f"Running Blackwell SM100 FMHA test with:")
|
||||
@ -2704,13 +2707,17 @@ def run_fmha_and_verify(
|
||||
print(f" pv_acc_dtype: {pv_acc_dtype}")
|
||||
print(f" mma_tiler_mn: {mma_tiler_mn}")
|
||||
print(f" is_persistent: {is_persistent}")
|
||||
print(f" has_casual_mask: {has_casual_mask}")
|
||||
print(f" is_causal: {is_causal}")
|
||||
print(f" scale_q: {scale_q}")
|
||||
print(f" scale_k: {scale_k}")
|
||||
print(f" scale_v: {scale_v}")
|
||||
print(f" inv_scale_o: {inv_scale_o}")
|
||||
print(f" scale_softmax: {scale_softmax}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print(f" warmup_iterations: {warmup_iterations}")
|
||||
print(f" iterations: {iterations}")
|
||||
print(f" skip_ref_check: {skip_ref_check}")
|
||||
print(f" use_cold_l2: {use_cold_l2}")
|
||||
|
||||
# Unpack parameters
|
||||
b, s_q, h_q, d = q_shape
|
||||
@ -2882,7 +2889,7 @@ def run_fmha_and_verify(
|
||||
mma_tiler = (*mma_tiler_mn, d)
|
||||
|
||||
mask_type = MaskType.NO_MASK
|
||||
if has_casual_mask:
|
||||
if is_causal:
|
||||
mask_type = MaskType.CAUSAL_MASK
|
||||
else:
|
||||
if isinstance(s_k, tuple):
|
||||
@ -2942,41 +2949,7 @@ def run_fmha_and_verify(
|
||||
compilation_time = time.time() - start_time
|
||||
print(f"Compilation time: {compilation_time:.4f} seconds")
|
||||
|
||||
# Warmup
|
||||
for _ in range(warmup_iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
# Execute kernel
|
||||
for _ in range(iterations):
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_torch_fmha(
|
||||
q, k, v, scale_softmax=1.0, scale_output=1.0, has_casual_mask=False
|
||||
):
|
||||
def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False):
|
||||
h_q = q.shape[2]
|
||||
h_k = k.shape[2]
|
||||
|
||||
@ -3005,7 +2978,7 @@ def run_fmha_and_verify(
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# For the situation that torch has not supported, we need to handle it manually
|
||||
situation1 = has_casual_mask and (q.is_nested or k.is_nested)
|
||||
situation1 = is_causal and (q.is_nested or k.is_nested)
|
||||
situation2 = (q.is_nested and not k.is_nested) or (
|
||||
not q.is_nested and k.is_nested
|
||||
)
|
||||
@ -3025,8 +2998,9 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref_i = ref_i.transpose(0, 1) * scale_output
|
||||
ref_list.append(ref_i)
|
||||
if q.is_nested:
|
||||
ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged)
|
||||
@ -3040,15 +3014,28 @@ def run_fmha_and_verify(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
scale=scale_softmax,
|
||||
is_causal=has_casual_mask,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
ref = ref.transpose(1, 2) * scale_output
|
||||
return ref
|
||||
|
||||
if not skip_ref_check:
|
||||
# Execute kernel once for reference checking
|
||||
compiled_fmha(
|
||||
q_tensor.iterator,
|
||||
k_tensor.iterator,
|
||||
v_tensor.iterator,
|
||||
o_tensor.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
print("Verifying results...")
|
||||
o_ref = run_torch_fmha(
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, has_casual_mask
|
||||
q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal
|
||||
)
|
||||
|
||||
if o_ref.is_nested:
|
||||
@ -3095,6 +3082,76 @@ def run_fmha_and_verify(
|
||||
torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05)
|
||||
print("Results verified successfully!")
|
||||
|
||||
def generate_tensors():
|
||||
_, q_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, k_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, v_tensor_workspace, _ = create_and_pad_tensor(
|
||||
kv_shape,
|
||||
kv_padding,
|
||||
in_dtype,
|
||||
s_cumsum=cum_seqlen_k_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
_, o_tensor_workspace, _ = create_and_pad_tensor(
|
||||
qo_shape,
|
||||
qo_padding,
|
||||
out_dtype,
|
||||
s_cumsum=cum_seqlen_q_torch,
|
||||
is_dynamic_layout=True,
|
||||
)
|
||||
return testing.JitArguments(
|
||||
q_tensor_workspace.iterator,
|
||||
k_tensor_workspace.iterator,
|
||||
v_tensor_workspace.iterator,
|
||||
o_tensor_workspace.iterator,
|
||||
problem_size,
|
||||
cum_seqlen_q,
|
||||
cum_seqlen_k,
|
||||
scale_softmax_log2,
|
||||
scale_output,
|
||||
current_stream,
|
||||
)
|
||||
|
||||
workspace_count = 1
|
||||
if use_cold_l2:
|
||||
q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch
|
||||
k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch
|
||||
v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch
|
||||
o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch
|
||||
one_workspace_bytes = (
|
||||
q_torch_effective.numel() * q_torch_effective.element_size()
|
||||
+ k_torch_effective.numel() * k_torch_effective.element_size()
|
||||
+ v_torch_effective.numel() * v_torch_effective.element_size()
|
||||
+ o_torch_effective.numel() * o_torch_effective.element_size()
|
||||
)
|
||||
workspace_count = testing.get_workspace_count(
|
||||
one_workspace_bytes, warmup_iterations, iterations
|
||||
)
|
||||
|
||||
exec_time = testing.benchmark(
|
||||
compiled_fmha,
|
||||
workspace_generator=generate_tensors,
|
||||
workspace_count=workspace_count,
|
||||
stream=current_stream,
|
||||
warmup_iterations=warmup_iterations,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
@ -3185,7 +3242,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--has_casual_mask",
|
||||
"--is_causal",
|
||||
action="store_true",
|
||||
help="Whether to use casual mask",
|
||||
)
|
||||
@ -3263,6 +3320,13 @@ if __name__ == "__main__":
|
||||
help="Skip reference check",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cold_l2",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use circular buffer tensor sets to ensure L2 cold cache",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.q_shape) != 4:
|
||||
@ -3279,7 +3343,7 @@ if __name__ == "__main__":
|
||||
|
||||
torch.manual_seed(1111)
|
||||
|
||||
run_fmha_and_verify(
|
||||
run(
|
||||
args.q_shape,
|
||||
args.k_shape,
|
||||
args.in_dtype,
|
||||
@ -3288,7 +3352,7 @@ if __name__ == "__main__":
|
||||
args.pv_acc_dtype,
|
||||
args.mma_tiler_mn,
|
||||
args.is_persistent,
|
||||
args.has_casual_mask,
|
||||
args.is_causal,
|
||||
args.scale_q,
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
@ -3298,6 +3362,7 @@ if __name__ == "__main__":
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.use_cold_l2,
|
||||
)
|
||||
|
||||
print("PASS")
|
||||
|
||||
Reference in New Issue
Block a user