# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
import torch.nn.functional as F
def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim):
    """
    Full-fp32 SSD reference: rearrange tensor dimensions from cuda layout to
    reference layout, then directly call TriDao's ssd implementation.

    Shapes below are written as (dims):(strides). Dim letters: L = seqlen per
    chunk, C = #chunks, H = #heads, G = #B/C groups, N = d_state, D = d_head,
    B (trailing dim) = batch.

    Arguments:
        X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
        A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
        a: (H):(1)
        B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
        D: (1, H):(0, 1) or (D, H):(1, D)
        has_d: bool, apply the D skip connection when True
        d_has_hdim: bool, True if D carries its own head-channel (d_head) dim
    Return (written in place; the function itself returns None):
        Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
        Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
    """
    assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

    # Discretization: A = delta * a (per-head scalar a broadcast over L/C/B),
    # and the input is pre-scaled by delta (X = x * delta).
    A = delta * a.view(1, 1, -1, 1)
    X = x * delta.unsqueeze(0)

    # Rearrange to match cutlass layout to tridao's layout
    block_len = A.shape[0]
    initial_states = None
    # A: l c h b-> b c l h
    A = A.permute(3, 1, 0, 2)
    # X: p l c h b -> b c l h p
    X = X.permute(4, 2, 1, 3, 0)
    # B: l n c g b -> b c l g n
    B = B.permute(4, 2, 0, 3, 1)
    # C: l n c g b -> b c l g n
    C = C.permute(4, 2, 0, 3, 1)
    # X/A/B/C: b c l ... -> b (c l) ...  (merge chunk and intra-chunk dims)
    X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)]

    # Ngroup (g to h) mapping: expand grouped B/C so every head h reads the
    # row of its group g = h // (H/G).
    B_val, CL_val, G_val, N_val = B.shape
    H_val = X.shape[2]
    ngroup_ratio = H_val // G_val
    # B/C: (B, CL, H, N)
    h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
    B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
    C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

    ###################################################################
    # Call reference implementation from Tri Dao ssd_minimal_discrete
    Y, final_state = ssd_minimal_discrete_fp32_all(
        X, A, B, C, block_len, initial_states
    )
    ###################################################################

    if has_d:
        # D skip connection: Y += X * D. When D has no head-channel dim it is
        # broadcast (stride-0 expand) across d_head.
        D_val = Y.shape[3]
        if not d_has_hdim:
            D = D.expand(D_val, -1)
        Y = Y + torch.einsum("bchp,ph->bchp", X, D)

    # Rearrange to match tridao's layout to cutlass layout
    # Y: b (c l) h p -> b c l h p
    Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
    # Y: b c l h p -> l p c h b
    Y = Y.permute(2, 4, 1, 3, 0)
    # Fstate_out: b h p n -> p n h b
    Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
    Y_out.copy_(Y)
    return
def ssd_reference_lowprecision_intermediates(
    x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim
):
    """
    Rearrange tensor dimensions from cuda layout to reference layout, then call
    a reduced-intermediate-dtype version of the ssd implementation.

    Shapes below are written as (dims):(strides). Dim letters: L = seqlen per
    chunk, C = #chunks, H = #heads, G = #B/C groups, N = d_state, D = d_head,
    B (trailing dim) = batch.

    Arguments:
        X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L)
        A/delta: (L, C, H, B):(1, L, C*L, H*C*L)
        a: (H):(1)
        B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L)
        intermediate_dtype: input and intermediate data type
        D: (1, H):(0, 1) or (D, H):(1, D)
        has_d: bool, apply the D skip connection when True
        d_has_hdim: bool, True if D carries its own head-channel (d_head) dim
    Return (written in place; the function itself returns None):
        Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L)
        Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N)
    """
    assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype

    # Discretize A only; unlike the fp32 reference, delta is NOT pre-multiplied
    # into x here — it is applied inside the low-precision kernel to match the
    # GPU implementation.
    A = delta * a.view(1, 1, -1, 1)

    # Rearrange to match cutlass layout to tridao's layout
    block_len = A.shape[0]
    initial_states = None
    # A: l c h b-> b c l h
    A = A.permute(3, 1, 0, 2)
    # delta: l c h b-> b c l h
    delta = delta.permute(3, 1, 0, 2)
    # x: p l c h b -> b c l h p
    x = x.permute(4, 2, 1, 3, 0)
    # B: l n c g b -> b c l g n
    B = B.permute(4, 2, 0, 3, 1)
    # C: l n c g b -> b c l g n
    C = C.permute(4, 2, 0, 3, 1)
    # x/A/delta/B/C: b c l ... -> b (c l) ...  (merge chunk and intra-chunk dims)
    x, A, delta, B, C = [
        tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:])
        for tensor in (x, A, delta, B, C)
    ]

    # Ngroup (g to h) mapping: expand grouped B/C so every head h reads the
    # row of its group g = h // (H/G).
    B_val, CL_val, G_val, N_val = B.shape
    H_val = x.shape[2]
    ngroup_ratio = H_val // G_val
    # B/C: (B, CL, H, N)
    h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio
    B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))
    C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val))

    # Type convert input tensors to input dtype (same as intermediate dtype).
    # The round trip through intermediate_dtype simulates the precision loss of
    # the GPU inputs while keeping subsequent math in fp32.
    x = x.to(intermediate_dtype).to(torch.float32)
    A = A.to(intermediate_dtype).to(torch.float32)
    delta = delta.to(intermediate_dtype).to(torch.float32)
    B = B.to(intermediate_dtype).to(torch.float32)
    C = C.to(intermediate_dtype).to(torch.float32)

    #########################################################################
    # Call reference implementation ssd_minimal_discrete_bf16_intermediates
    Y, final_state = ssd_minimal_discrete_lowprecision_intermediates(
        x, A, delta, B, C, block_len, intermediate_dtype, initial_states
    )
    #########################################################################

    if has_d:
        D = D.to(intermediate_dtype).to(torch.float32)
        D_val = Y.shape[3]
        if not d_has_hdim:
            # Broadcast (stride-0 expand) D across d_head.
            D = D.expand(D_val, -1)
        # D skip connection: Y += x * D.
        Y = Y + torch.einsum("bchp,ph->bchp", x, D)

    # Type convert output tensors to output dtype (same as intermediate dtype)
    Y = Y.to(intermediate_dtype).to(torch.float32)
    final_state = final_state.to(intermediate_dtype).to(torch.float32)

    # Rearrange to match tridao's layout to cutlass layout
    # Y: b (c l) h p -> b c l h p
    Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3])
    # Y: b c l h p -> l p c h b
    Y = Y.permute(2, 4, 1, 3, 0)
    # Fstate_out: b h p n -> p n h b
    Fstate_out.copy_(final_state.permute(2, 3, 1, 0))
    Y_out.copy_(Y)
    return
def analyze_relative_diffs(actual, expected):
    """
    Print statistics of relative differences between actual and expected tensors:
    the location/value of the maximum relative difference, NaN/Inf counts, and a
    histogram of elements per rtol bucket.
    """
    # Calculate relative differences; the small epsilon guards against 0/0.
    abs_diff = (actual - expected).abs()
    rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001)

    total_elements = rel_diff.numel()

    # Handle special cases first
    nan_mask = torch.isnan(rel_diff)
    inf_mask = torch.isinf(rel_diff)
    nan_count = nan_mask.sum().item()
    inf_count = inf_mask.sum().item()

    # Find position and value of maximum relative difference.
    # BUG FIX: the previous code took argmax over the boolean-mask-compacted
    # tensor rel_diff[~nan_mask & ~inf_mask], whose indices do NOT correspond
    # to positions in the original flattened tensor whenever NaN/Inf elements
    # exist, so the printed position/actual/expected values were shifted.
    # Instead, squash NaN/Inf to -inf and take argmax over the full tensor so
    # the index is valid for actual.flatten()/expected.flatten().
    finite_mask = ~nan_mask & ~inf_mask
    if finite_mask.any():
        finite_rel_diff = rel_diff.masked_fill(~finite_mask, float("-inf"))
        max_rel_diff_pos = finite_rel_diff.flatten().argmax()
        max_rel_diff = rel_diff.flatten()[max_rel_diff_pos]
    else:
        max_rel_diff = float("nan")
        max_rel_diff_pos = -1

    # Print max relative difference info
    print(f"Maximum relative difference:")
    print(f"Position: {max_rel_diff_pos}")
    print(f"Value: {max_rel_diff:.6e}")
    print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}")
    print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}")
    print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)")
    print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n")

    # Check different rtol thresholds (bucketed histogram; NaN compares False
    # against every threshold, so NaN elements fall into no bucket).
    rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01]

    for i, rtol in enumerate(rtol_levels):
        if i == 0:
            mask = rel_diff <= rtol
        else:
            mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1])

        count = mask.sum().item()
        percentage = (count / total_elements) * 100

        if i == 0:
            print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)")
        else:
            print(
                f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)"
            )

    # Print elements exceeding the largest rtol
    mask = rel_diff > rtol_levels[-1]
    count = mask.sum().item()
    percentage = (count / total_elements) * 100
    print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n")
def segsum(x):
    """
    Numerically stable segment-sum along the last axis.

    Returns a tensor of shape (*x.shape, T) where entry [..., i, j] equals
    sum(x[..., j+1 : i+1]) for j <= i and -inf above the diagonal.
    x: b h c l
    """
    seq_len = x.size(-1)
    # Repeat the last axis: element [..., i, j] holds x[..., i].
    # x: b h c l -> b h c l l
    repeated = x.unsqueeze(-1).expand(*x.shape, seq_len)
    # Zero out everything on/above the diagonal, then cumulate down columns so
    # column j accumulates only the strictly-later entries.
    strict_lower = torch.tril(
        torch.ones(seq_len, seq_len, device=x.device, dtype=bool), diagonal=-1
    )
    summed = torch.cumsum(repeated.masked_fill(~strict_lower, 0), dim=-2)
    # Invalid (upper-triangular) segments become -inf so exp() maps them to 0.
    lower_incl = torch.tril(
        torch.ones(seq_len, seq_len, device=x.device, dtype=bool), diagonal=0
    )
    return summed.masked_fill(~lower_incl, -torch.inf)
def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None):
    """
    This is same with https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py
    (all accumulation and intermediate results in fp32)

    Arguments:
        X: (batch(B), length(C*L), n_heads(H), d_head(D))
        A: (batch(B), length(C*L), n_heads(H))
        B: (batch(B), length(C*L), n_heads(H), d_state(N))
        C: (batch(B), length(C*L), n_heads(H), d_state(N))
        block_len: chunk length L; total length must be divisible by it
        initial_states: optional starting chunk states; zeros when None
    Return:
        Y: (batch(B), length(C*L), n_heads(H), d_head(D))
        final_state: (B, H, D, N)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    # Sequence length must split evenly into chunks of block_len.
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    # X/A/B/C:b (c l) ... -> b c l ...
    X, A, B, C = [
        x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C)
    ]

    # A: b c l h -> b h c l
    A = A.permute(0, 3, 1, 2)
    # A_cumsum: (B, H, C, L) — running log-decay within each chunk
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    segsum_A = segsum(A)
    # L: lower-triangular per-chunk decay matrix; exp(-inf) zeroes the upper
    # triangle, enforcing causality within the chunk.
    L = torch.exp(segsum_A)
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # Split: states entering each chunk vs. the state after the last chunk.
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    # Y: b c l h p -> b (c l) h p
    Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4])
    return Y, final_state
def ssd_minimal_discrete_lowprecision_intermediates(
    X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None
):
    """
    This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions:
    1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype
    2. delta is not pre-multiplied with X, delta was applied to generate Q/b_tmem to match GPU implementation

    Arguments:
        X: (batch(B), length(C*L), n_heads(H), d_head(D))
        A: (batch(B), length(C*L), n_heads(H))
        delta: (batch(B), length(C*L), n_heads(H))
        B: (batch(B), length(C*L), n_heads(H), d_state(N))
        C: (batch(B), length(C*L), n_heads(H), d_state(N))
        block_len: chunk length L; total length must be divisible by it
        intermediate_dtype: dtype the intermediate products are rounded
            through (each round trip back to fp32 emulates the GPU kernel's
            reduced-precision operands)
        initial_states: optional starting chunk states; zeros when None
    Return:
        Y: (batch(B), length(C*L), n_heads(H), d_head(D))
        final_state: (B, H, D, N)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    # Sequence length must split evenly into chunks of block_len.
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    # X/A/delta/B/C: b (c l) ... -> b c l ...
    X, A, delta, B, C = [
        x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C)
    ]

    # A: b c l h -> b h c l
    A = A.permute(0, 3, 1, 2)
    # delta: b c l h -> b h c l
    delta = delta.permute(0, 3, 1, 2)
    # A_cumsum: (B, H, C, L)
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    segsum_A = segsum(A)
    L = torch.exp(segsum_A)
    intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B)
    # Q folds the decay matrix and delta into (C·B); it is rounded to
    # intermediate_dtype before the final contraction with X.
    Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta)
    Y_diag = torch.einsum(
        "bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X
    )

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    # b_tmem carries B * decay * delta, rounded to intermediate_dtype before
    # being contracted against X.
    b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta)
    states = torch.einsum(
        "bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X
    )

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # Split: states entering each chunk vs. the state after the last chunk.
    # (Removed a redundant no-op `final_state = final_state` assignment.)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    # States are rounded to intermediate_dtype before being read back out.
    Y_off_tmp = torch.einsum(
        "bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32)
    )
    Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    # Y: b c l h p -> b (c l) h p
    Y = (Y_diag + Y_off).reshape(
        Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]
    )  # b (c l) h p
    return Y, final_state