@ -206,12 +206,12 @@ def _topk_log_softmax_kernel(
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
max_val = -float("inf")
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=-float("inf"))
|
||||
other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(l, max_val))
|
||||
|
||||
se = 0.0
|
||||
|
||||
Reference in New Issue
Block a user