Compare commits
1 Commits
main
...
maybe_fix_
| Author | SHA1 | Date | |
|---|---|---|---|
| cd3ea013d6 |
@ -167,6 +167,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||
q_nope_padded[:, :H] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||
q_pe_padded[:, :H] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
@ -209,8 +217,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H].contiguous()
|
||||
|
||||
return out, lse
|
||||
|
||||
|
||||
Reference in New Issue
Block a user