Compare commits
1 Commits
use-uv-pyt
...
maybe_fix_
| Author | SHA1 | Date | |
|---|---|---|---|
| cd3ea013d6 |
@@ -167,6 +167,14 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
MAX_HEADS = 128
|
MAX_HEADS = 128
|
||||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||||
|
if H < MAX_HEADS:
|
||||||
|
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||||
|
q_nope_padded[:, :H] = q_nope
|
||||||
|
q_nope = q_nope_padded
|
||||||
|
|
||||||
|
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||||
|
q_pe_padded[:, :H] = q_pe
|
||||||
|
q_pe = q_pe_padded
|
||||||
|
|
||||||
assert len(page_table.shape) == 2
|
assert len(page_table.shape) == 2
|
||||||
B_block_table, block_num = page_table.shape
|
B_block_table, block_num = page_table.shape
|
||||||
@@ -209,8 +217,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
if H < MAX_HEADS:
|
if H < MAX_HEADS:
|
||||||
# Extract the subsets of the outputs
|
# Extract the subsets of the outputs
|
||||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
lse = lse[:, :H].contiguous(
|
||||||
out = out[:, :H]
|
) if self.need_to_return_lse_for_decode else lse
|
||||||
|
out = out[:, :H].contiguous()
|
||||||
|
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user