v4.2 tag release. (#2638)
This commit is contained in:
@@ -510,7 +510,6 @@ class PersistentDenseGemmKernel:
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
return
|
||||
@@ -669,7 +668,7 @@ class PersistentDenseGemmKernel:
|
||||
gC_mnl = cute.local_tile(
|
||||
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
|
||||
)
|
||||
k_block_cnt = cute.size(gA_mkl, mode=[3])
|
||||
k_tile_cnt = cute.size(gA_mkl, mode=[3])
|
||||
|
||||
#
|
||||
# Partition global tensor for TiledMMA_A/B/C
|
||||
@@ -774,17 +773,17 @@ class PersistentDenseGemmKernel:
|
||||
(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])
|
||||
]
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
|
||||
ab_producer_state.reset_count()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
#
|
||||
# Tma load loop
|
||||
#
|
||||
for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1):
|
||||
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
|
||||
# Conditionally wait for AB buffer empty
|
||||
ab_pipeline.producer_acquire(
|
||||
ab_producer_state, peek_ab_empty_status
|
||||
@@ -806,10 +805,10 @@ class PersistentDenseGemmKernel:
|
||||
mcast_mask=b_full_mcast_mask,
|
||||
)
|
||||
|
||||
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1
|
||||
# Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
|
||||
ab_producer_state.advance()
|
||||
peek_ab_empty_status = cutlass.Boolean(1)
|
||||
if ab_producer_state.count < k_block_cnt:
|
||||
if ab_producer_state.count < k_tile_cnt:
|
||||
peek_ab_empty_status = ab_pipeline.producer_try_acquire(
|
||||
ab_producer_state
|
||||
)
|
||||
@@ -877,10 +876,10 @@ class PersistentDenseGemmKernel:
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = 0
|
||||
# Peek (try_wait) AB buffer full for k_tile = 0
|
||||
ab_consumer_state.reset_count()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt and is_leader_cta:
|
||||
if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
)
|
||||
@@ -899,7 +898,7 @@ class PersistentDenseGemmKernel:
|
||||
#
|
||||
# Mma mainloop
|
||||
#
|
||||
for k_block in range(k_block_cnt):
|
||||
for k_tile in range(k_tile_cnt):
|
||||
if is_leader_cta:
|
||||
# Conditionally wait for AB buffer full
|
||||
ab_pipeline.consumer_wait(
|
||||
@@ -907,32 +906,32 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
num_kphases = cute.size(tCrA, mode=[2])
|
||||
for kphase_idx in cutlass.range(num_kphases, unroll_full=True):
|
||||
kphase_coord = (
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
|
||||
kblock_coord = (
|
||||
None,
|
||||
None,
|
||||
kphase_idx,
|
||||
kblock_idx,
|
||||
ab_consumer_state.index,
|
||||
)
|
||||
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[kphase_coord],
|
||||
tCrB[kphase_coord],
|
||||
tCrA[kblock_coord],
|
||||
tCrB[kblock_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
# Enable accumulate on tCtAcc after first kphase
|
||||
# Enable accumulate on tCtAcc after first kblock
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Async arrive AB buffer empty
|
||||
ab_pipeline.consumer_release(ab_consumer_state)
|
||||
|
||||
# Peek (try_wait) AB buffer full for k_block = k_block + 1
|
||||
# Peek (try_wait) AB buffer full for k_tile = k_tile + 1
|
||||
ab_consumer_state.advance()
|
||||
peek_ab_full_status = cutlass.Boolean(1)
|
||||
if ab_consumer_state.count < k_block_cnt:
|
||||
if ab_consumer_state.count < k_tile_cnt:
|
||||
if is_leader_cta:
|
||||
peek_ab_full_status = ab_pipeline.consumer_try_wait(
|
||||
ab_consumer_state
|
||||
|
||||
Reference in New Issue
Block a user