Fix oom
This commit is contained in:
@ -116,9 +116,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
num_decode_tokens: int,
|
||||
dcp_tot_seq_lens_device: torch.Tensor | None,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
# we use the max but all should be the same due to uniform length requirement
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens_device,
|
||||
self.num_q_heads,
|
||||
num_q_tokens_per_head_k,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user