v4.1 release
This commit is contained in:
3619
examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py
Normal file
3619
examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,397 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim):
    """
    Rearrange tensor dimensions from cuda layout to reference layout, then directly call TriDao's ssd implementation

    All math runs in fp32 (see ssd_minimal_discrete_fp32_all); results are
    written in place into Y_out and Fstate_out, and the function returns None.

    Arguments:
        X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
        A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
        a: (H):(1)
        B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
        D: (1, H):(0, 1) or (D, H):(1, D)
        has_d: bool
        d_has_hdim: bool
    Return:
        Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
        Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
    """
    assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

    # Discretization: scale the per-head decay `a` by delta (broadcast over
    # the H axis) and pre-multiply the input by delta.
    A = delta * a.view(1, 1, -1, 1)
    X = x * delta.unsqueeze(0)

    # Rearrange to match cutlass layout to tridao's layout
    block_len = A.shape[0]
    initial_states = None
    # A: l c h b-> b c l h
    A = A.permute(3, 1, 0, 2)
    # X: p l c h b -> b c l h p
    X = X.permute(4, 2, 1, 3, 0)
    # B: l n c g b -> b c l g n
    B = B.permute(4, 2, 0, 3, 1)
    # C: l n c g b -> b c l g n
    C = C.permute(4, 2, 0, 3, 1)
    # X/A/B/C: b c l ... -> b (c l) ...
    X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)]

    # Ngroup (g to h) mapping: replicate each of the G groups across its
    # H // G heads so B/C become per-head tensors.
    B_val, CL_val, G_val, N_val = B.shape
    H_val = X.shape[2]
    ngroup_ratio = H_val // G_val
    # B/C: (B, CL, H, N)
    h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
    B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
    C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

    ###################################################################
    # Call reference implementation from Tri Dao ssd_minimal_discrete
    Y, final_state = ssd_minimal_discrete_fp32_all(
        X, A, B, C, block_len, initial_states
    )
    ###################################################################

    if has_d:
        D_val = Y.shape[3]
        if not d_has_hdim:
            # Broadcast the single scalar-per-head D across the head dim.
            D = D.expand(D_val, -1)
        # Skip connection.
        # NOTE(review): here D multiplies the delta-scaled X, whereas
        # ssd_reference_lowprecision_intermediates applies D to the raw x —
        # confirm this asymmetry is intentional.
        Y = Y + torch.einsum("bchp,ph->bchp", X, D)

    # Rearrange to match tridao's layout to cutlass layout
    # Y: b (c l) h p -> b c l h p
    Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
    # Y: b c l h p -> l p c h b
    Y = Y.permute(2, 4, 1, 3, 0)
    # Fstate_out: b h p n -> p n h b
    Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
    Y_out.copy_(Y)
    return
|
||||
|
||||
|
||||
def ssd_reference_lowprecision_intermediates(
    x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim
):
    """
    Rearrange tensor dimensions from cuda layout to reference layout, then call a reduced intermediate dtype version of ssd implementation

    Unlike ssd_reference_fp32_all, inputs (and selected intermediates inside
    the minimal-discrete kernel) are rounded through `intermediate_dtype` to
    mimic the precision behavior of the GPU implementation. Results are
    written in place into Y_out and Fstate_out; the function returns None.

    Arguments:
        X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
        A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
        a: (H):(1)
        B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
        intermediate_dtype: input and intermediate data type
        D: (1, H):(0, 1) or (D, H):(1, D)
        has_d: bool
        d_has_hdim: bool
    Return:
        Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
        Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
    """
    assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

    # Discretize the decay only; delta is NOT pre-multiplied into x here —
    # it is applied inside the minimal-discrete kernel to match the GPU path.
    A = delta * a.view(1, 1, -1, 1)

    # Rearrange to match cutlass layout to tridao's layout
    block_len = A.shape[0]
    initial_states = None
    # A: l c h b-> b c l h
    A = A.permute(3, 1, 0, 2)
    # delta: l c h b-> b c l h
    delta = delta.permute(3, 1, 0, 2)
    # x: p l c h b -> b c l h p
    x = x.permute(4, 2, 1, 3, 0)
    # B: l n c g b -> b c l g n
    B = B.permute(4, 2, 0, 3, 1)
    # C: l n c g b -> b c l g n
    C = C.permute(4, 2, 0, 3, 1)
    # x/A/delta/B/C: b c l ... -> b (c l) ...
    x, A, delta, B, C = [
        tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:])
        for tensor in (x, A, delta, B, C)
    ]

    # Ngroup (g to h) mapping: replicate each of the G groups across its
    # H // G heads so B/C become per-head tensors.
    B_val, CL_val, G_val, N_val = B.shape
    H_val = x.shape[2]
    ngroup_ratio = H_val // G_val
    # B/C: (B, CL, H, N)
    h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
    B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
    C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

    # Type convert input tensors to input dtype (same as intermediate dtype).
    # The round trip back to fp32 keeps subsequent math in fp32 while baking
    # in the rounding of the lower-precision dtype.
    x = x.to(intermediate_dtype).to(torch.float32)
    A = A.to(intermediate_dtype).to(torch.float32)
    delta = delta.to(intermediate_dtype).to(torch.float32)
    B = B.to(intermediate_dtype).to(torch.float32)
    C = C.to(intermediate_dtype).to(torch.float32)

    #########################################################################
    # Call reference implementation ssd_minimal_discrete_bf16_intermediates
    Y, final_state = ssd_minimal_discrete_lowprecision_intermediates(
        x, A, delta, B, C, block_len, intermediate_dtype, initial_states
    )
    #########################################################################

    if has_d:
        D = D.to(intermediate_dtype).to(torch.float32)
        D_val = Y.shape[3]
        if not d_has_hdim:
            # Broadcast the single scalar-per-head D across the head dim.
            D = D.expand(D_val, -1)
        # Skip connection on the raw (delta-free) input x.
        Y = Y + torch.einsum("bchp,ph->bchp", x, D)

    # Type convert output tensors to output dtype (same as intermediate dtype)
    Y = Y.to(intermediate_dtype).to(torch.float32)
    final_state = final_state.to(intermediate_dtype).to(torch.float32)

    # Rearrange to match tridao's layout to cutlass layout
    # Y: b (c l) h p -> b c l h p
    Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
    # Y: b c l h p -> l p c h b
    Y = Y.permute(2, 4, 1, 3, 0)
    # Fstate_out: b h p n -> p n h b
    Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
    Y_out.copy_(Y)
    return
|
||||
|
||||
|
||||
def analyze_relative_diffs(actual, expected):
    """
    Print statistics of relative differences between actual and expected tensors.

    Reports the position/value of the largest finite relative difference,
    NaN/Inf counts, and a histogram of elements per rtol bucket. Purely a
    diagnostic: prints to stdout and returns None.
    """
    # Relative difference, guarded against division by zero with an epsilon.
    abs_diff = (actual - expected).abs()
    rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001)

    total_elements = rel_diff.numel()

    # Handle special cases first
    nan_mask = torch.isnan(rel_diff)
    inf_mask = torch.isinf(rel_diff)
    nan_count = nan_mask.sum().item()
    inf_count = inf_mask.sum().item()

    # Find position and value of the maximum finite relative difference.
    # Bug fix: the argmax must be taken over the full flattened tensor (with
    # NaN/Inf entries suppressed via masked_fill). The previous code took
    # argmax of the boolean-filtered tensor, whose compacted indices do not
    # line up with actual.flatten() whenever NaN/Inf elements exist.
    valid_mask = ~nan_mask & ~inf_mask
    if valid_mask.any():
        masked = rel_diff.masked_fill(~valid_mask, float("-inf"))
        max_rel_diff_pos = masked.flatten().argmax().item()
        max_rel_diff = rel_diff.flatten()[max_rel_diff_pos].item()
    else:
        max_rel_diff = float("nan")
        max_rel_diff_pos = -1

    # Print max relative difference info
    print("Maximum relative difference:")
    print(f"Position: {max_rel_diff_pos}")
    print(f"Value: {max_rel_diff:.6e}")
    print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}")
    print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}")
    print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)")
    print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n")

    # Check different rtol thresholds (bucketed: each element counted once).
    rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01]

    for i, rtol in enumerate(rtol_levels):
        if i == 0:
            mask = rel_diff <= rtol
        else:
            mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1])

        count = mask.sum().item()
        percentage = (count / total_elements) * 100

        if i == 0:
            print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)")
        else:
            print(
                f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)"
            )

    # Print elements exceeding the largest rtol
    mask = rel_diff > rtol_levels[-1]
    count = mask.sum().item()
    percentage = (count / total_elements) * 100
    print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n")
|
||||
|
||||
|
||||
def segsum(x):
    """
    More stable segment sum calculation.
    x: b h c l
    """
    seq_len = x.size(-1)
    # Broadcast the last axis into an (l, l) grid where entry [..., i, j]
    # equals x[..., i].
    grid = x.unsqueeze(-1).expand(*x.shape, seq_len)
    # Keep only strictly-lower-triangular entries, then prefix-sum down the
    # rows so that [..., i, j] holds the sum of x[..., j+1 .. i].
    strict_lower = torch.tril(
        torch.ones(seq_len, seq_len, device=x.device, dtype=bool), diagonal=-1
    )
    grid = grid.masked_fill(~strict_lower, 0)
    sums = torch.cumsum(grid, dim=-2)
    # Entries above the diagonal correspond to invalid (backward) segments:
    # force them to -inf so exp() maps them to zero.
    incl_lower = torch.tril(
        torch.ones(seq_len, seq_len, device=x.device, dtype=bool), diagonal=0
    )
    return sums.masked_fill(~incl_lower, -torch.inf)
|
||||
|
||||
|
||||
def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None):
    """
    This is same with https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
    (all accumulation and intermediate results in fp32)

    Arguments:
        X: (batch(B), length(C*L), n_heads(H), d_head(D))
        A: (batch(B), length(C*L), n_heads(H))
        B: (batch(B), length(C*L), n_heads(H), d_state(N))
        C: (batch(B), length(C*L), n_heads(H), d_state(N))
    Return:
        Y: (batch(B), length(C*L), n_heads(H), d_head(D))
        final_state: (B, H, D, N)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    # Sequence length must divide evenly into chunks of block_len.
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    # X/A/B/C:b (c l) ... -> b c l ...
    X, A, B, C = [
        x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C)
    ]

    # A: b c l h -> b h c l
    A = A.permute(0, 3, 1, 2)
    # A_cumsum: (B, H, C, L)
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    # L[..., i, j] = exp(sum A over j+1..i): causal decay mask within a chunk.
    segsum_A = segsum(A)
    L = torch.exp(segsum_A)
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    # Pad with a leading zero so the recurrence also covers the initial state.
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # states: state entering each chunk; final_state: state after last chunk.
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    # Y: b c l h p -> b (c l) h p
    Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4])
    return Y, final_state
|
||||
|
||||
|
||||
def ssd_minimal_discrete_lowprecision_intermediates(
    X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None
):
    """
    This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions:
    1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype
    2. delta is not pre-multiplied with X, delta was applied to generate Q/b_tmem to match GPU implementation

    Arguments:
        X: (batch(B), length(C*L), n_heads(H), d_head(D))
        A: (batch(B), length(C*L), n_heads(H))
        delta: (batch(B), length(C*L), n_heads(H))
        B: (batch(B), length(C*L), n_heads(H), d_state(N))
        C: (batch(B), length(C*L), n_heads(H), d_state(N))
    Return:
        Y: (batch(B), length(C*L), n_heads(H), d_head(D))
        final_state: (B, H, D, N)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    # Sequence length must divide evenly into chunks of block_len.
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    # X/A/delta/B/C: b (c l) ... -> b c l ...
    X, A, delta, B, C = [
        x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C)
    ]

    # A: b c l h -> b h c l
    A = A.permute(0, 3, 1, 2)
    # delta: b c l h -> b h c l
    delta = delta.permute(0, 3, 1, 2)
    # A_cumsum: (B, H, C, L)
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    segsum_A = segsum(A)
    L = torch.exp(segsum_A)
    intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B)
    # Q carries the decay mask and delta; it is rounded through
    # intermediate_dtype before the matmul with X, mimicking the GPU kernel.
    Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta)
    Y_diag = torch.einsum(
        "bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X
    )

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    # b_tmem is likewise rounded through intermediate_dtype before use.
    b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta)
    states = torch.einsum(
        "bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X
    )

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # Fix: dropped the redundant no-op `final_state = final_state` line.
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off_tmp = torch.einsum(
        "bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32)
    )
    Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    # Y: b c l h p -> b (c l) h p
    Y = (Y_diag + Y_off).reshape(
        Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]
    )  # b (c l) h p
    return Y, final_state
|
||||
@ -0,0 +1,200 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean,
|
||||
Integer,
|
||||
Int32,
|
||||
min,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
dsl_user_op,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
import cutlass.cute as cute
|
||||
from cutlass.utils import WorkTileInfo
|
||||
|
||||
|
||||
class Mamba2SSDTileSchedulerParams:
    """Static (host-side) parameters for Mamba2SSDTileScheduler.

    Holds the linearized tile-space size, the per-batch expanded-head count,
    and the head-to-group ratio, plus the CuTeDSL MLIR (de)serialization
    hooks that let this object cross the host/device boundary.
    """

    def __init__(
        self,
        problem_shape_ntiles: int,
        eh: int,
        ngroup_ratio: int,
        *,
        loc=None,
        ip=None,
    ):
        # Total number of work tiles in the linearized tile space.
        self.problem_shape_ntiles = problem_shape_ntiles
        # Number of expanded heads per batch.
        self.eh = eh
        # Number of heads sharing one B/C group — presumably H // G; verify
        # against the kernel that constructs these params.
        self.ngroup_ratio = ngroup_ratio
        # Source location for DSL diagnostics.
        self._loc = loc

    def __extract_mlir_values__(self):
        # Flatten every field into a list of MLIR values, recording how many
        # values each field contributed so __new_from_mlir_values__ can split
        # the flat list back apart.
        values, self._values_pos = [], []
        for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]:
            obj_values = extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Inverse of __extract_mlir_values__: rebuild each field from its
        # slice of `values`. Relies on _values_pos set during extraction.
        obj_list = []
        for obj, n_items in zip(
            [self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos
        ):
            obj_list.append(new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)

    @dsl_user_op
    def get_grid_shape(
        self, max_active_clusters: Int32, *, loc=None, ip=None
    ) -> Tuple[Integer, Integer, Integer]:
        # One persistent CTA per tile, capped at the number of clusters that
        # can be resident simultaneously.
        return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1)
|
||||
|
||||
|
||||
class Mamba2SSDTileScheduler:
    """Persistent tile scheduler for the Mamba2 SSD kernel.

    Each persistent CTA walks the linearized tile space with a stride of
    ``num_persistent_ctas``; every linear index maps to a
    (b_idx, eh_idx, g_idx) tile coordinate.
    """

    def __init__(
        self,
        params: Mamba2SSDTileSchedulerParams,
        num_persistent_ctas: Int32,
        current_work_linear_idx: Int32,
        num_tiles_executed: Int32,
    ):
        self.params = params
        self.num_persistent_ctas = num_persistent_ctas
        self._current_work_linear_idx = current_work_linear_idx
        self._num_tiles_executed = num_tiles_executed

    def __extract_mlir_values__(self) -> list[ir.Value]:
        # Serialize only the dynamic state; `params` is reattached on rebuild.
        values = extract_mlir_values(self.num_persistent_ctas)
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self._num_tiles_executed))
        return values

    def __new_from_mlir_values__(
        self, values: list[ir.Value]
    ) -> "Mamba2SSDTileScheduler":
        # Inverse of __extract_mlir_values__: one value per dynamic field.
        assert len(values) == 3
        new_num_persistent_ctas = new_from_mlir_values(
            self.num_persistent_ctas, [values[0]]
        )
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[1]]
        )
        new_num_tiles_executed = new_from_mlir_values(
            self._num_tiles_executed, [values[2]]
        )
        return Mamba2SSDTileScheduler(
            self.params,
            new_num_persistent_ctas,
            new_current_work_linear_idx,
            new_num_tiles_executed,
        )

    # called by host
    # NOTE(review): @dsl_user_op wraps the staticmethod object here (opposite
    # order from plain @staticmethod on get_grid_shape) — confirm intended.
    @dsl_user_op
    @staticmethod
    def create(
        params: Mamba2SSDTileSchedulerParams,
        block_idx: Tuple[Integer, Integer, Integer],
        grid_dim: Tuple[Integer, Integer, Integer],
        *,
        loc=None,
        ip=None,
    ):
        """Build a scheduler for the CTA at `block_idx` within `grid_dim`."""
        # Fix: removed the redundant no-op `params = params` self-assignment.

        # Calculate the number of persistent clusters by dividing the total grid size
        # by the number of CTAs per cluster
        num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip))

        bidx, bidy, bidz = block_idx

        # Initialize workload index equals to the cluster index in the grid
        current_work_linear_idx = Int32(bidx)

        # Initialize number of tiles executed to zero
        num_tiles_executed = Int32(0)
        return Mamba2SSDTileScheduler(
            params,
            num_persistent_ctas,
            current_work_linear_idx,
            num_tiles_executed,
        )

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Mamba2SSDTileSchedulerParams,
        max_active_clusters: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Integer, Integer, Integer]:
        # Delegates to the params object, which caps the grid at the number
        # of simultaneously-resident clusters.
        return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)

    # private method
    def _get_current_work_for_linear_idx(
        self, current_work_linear_idx: Int32, *, loc=None, ip=None
    ) -> WorkTileInfo:
        """Map a linear work index to a (b_idx, eh_idx, g_idx) tile coord."""
        # Indices past the tile count are invalid (CTA has run out of work).
        is_valid = current_work_linear_idx < cute.size(
            self.params.problem_shape_ntiles, loc=loc, ip=ip
        )

        eh_idx = current_work_linear_idx % self.params.eh
        b_idx = current_work_linear_idx // self.params.eh
        # Map the expanded-head index to its B/C group.
        g_idx = eh_idx // self.params.ngroup_ratio
        # cur_tile_coord is (b_idx, eh_idx, g_idx)
        cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx))

        return WorkTileInfo(cur_tile_coord, is_valid)

    @dsl_user_op
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        return self._get_current_work_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
        return self.get_current_work(loc=loc, ip=ip)

    @dsl_user_op
    def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
        # Stride by the persistent-CTA count so CTAs interleave over tiles.
        self._current_work_linear_idx += Int32(advance_count) * Int32(
            self.num_persistent_ctas
        )
        # NOTE(review): the executed counter bumps by 1 regardless of
        # advance_count — confirm this matches the intended semantics.
        self._num_tiles_executed += Int32(1)

    @property
    def num_tiles_executed(self) -> Int32:
        return self._num_tiles_executed
|
||||
Reference in New Issue
Block a user