v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@ -510,7 +510,6 @@ class PersistentDenseGemmKernel:
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
return
@ -669,7 +668,7 @@ class PersistentDenseGemmKernel:
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_block_cnt = cute.size(gA_mkl, mode=[3])
k_tile_cnt = cute.size(gA_mkl, mode=[3])
#
# Partition global tensor for TiledMMA_A/B/C
@ -774,17 +773,17 @@ class PersistentDenseGemmKernel:
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
]
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
ab_producer_state.reset_count()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
#
# Tma load loop
#
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
# Conditionally wait for AB buffer empty
ab_pipeline.producer_acquire(
ab_producer_state, peek_ab_empty_status
@ -806,10 +805,10 @@ class PersistentDenseGemmKernel:
mcast_mask=b_full_mcast_mask,
)
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
ab_producer_state.advance()
peek_ab_empty_status = cutlass.Boolean(1)
if ab_producer_state.count < k_block_cnt:
if ab_producer_state.count < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
ab_producer_state
)
@ -877,10 +876,10 @@ class PersistentDenseGemmKernel:
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# Peek (try_wait) AB buffer full for k_block = 0
# Peek (try_wait) AB buffer full for k_tile = 0
ab_consumer_state.reset_count()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state
)
@ -899,7 +898,7 @@ class PersistentDenseGemmKernel:
#
# Mma mainloop
#
for k_block in range(k_block_cnt):
for k_tile in range(k_tile_cnt):
if is_leader_cta:
# Conditionally wait for AB buffer full
ab_pipeline.consumer_wait(
@ -907,32 +906,32 @@ class PersistentDenseGemmKernel:
)
# tCtAcc += tCrA * tCrB
num_kphases = cute.size(tCrA, mode=[2])
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
kphase_coord = (
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (
None,
None,
kphase_idx,
kblock_idx,
ab_consumer_state.index,
)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kphase_coord],
tCrB[kphase_coord],
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
# Enable accumulate on tCtAcc after first kphase
# Enable accumulate on tCtAcc after first kblock
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
ab_pipeline.consumer_release(ab_consumer_state)
# Peek (try_wait) AB buffer full for k_block = k_block + 1
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
ab_consumer_state.advance()
peek_ab_full_status = cutlass.Boolean(1)
if ab_consumer_state.count < k_block_cnt:
if ab_consumer_state.count < k_tile_cnt:
if is_leader_cta:
peek_ab_full_status = ab_pipeline.consumer_try_wait(
ab_consumer_state