@ -78,11 +78,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
if self.runner.full_cuda_graph:
|
||||
n = num_splits.size(0)
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_num_splits is None:
|
||||
self.cg_buf_num_splits = num_splits
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
elif n <= self.cg_buf_num_splits.size(0):
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
self.cg_buf_num_splits = num_splits
|
||||
else:
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||
@ -93,7 +93,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
|
||||
Reference in New Issue
Block a user