Compare commits

...

1 Commits

Author SHA1 Message Date
cd3ea013d6 maybe fix
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
2025-09-27 17:49:34 -07:00

View File

@@ -167,6 +167,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
+ if H < MAX_HEADS:
+ q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
+ q_nope_padded[:, :H] = q_nope
+ q_nope = q_nope_padded
+ q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
+ q_pe_padded[:, :H] = q_pe
+ q_pe = q_pe_padded
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
@@ -209,8 +217,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
if H < MAX_HEADS:
# Extract the subsets of the outputs
- lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
- out = out[:, :H]
+ lse = lse[:, :H].contiguous(
+ ) if self.need_to_return_lse_for_decode else lse
+ out = out[:, :H].contiguous()
return out, lse