@ -254,8 +254,8 @@ def compute_logprobs(
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute the logprobs of the topk + 1
|
||||
# tokens.
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
BLOCK_SIZE = 1024
|
||||
_topk_logprobs_kernel[(batch_size, )](
|
||||
logprobs,
|
||||
|
||||
Reference in New Issue
Block a user