Compare commits

...

3 Commits

Author SHA1 Message Date
bcf3c8230d Merge branch 'main' into woosuk-jf 2025-05-04 11:16:07 -07:00
a01af39aa8 Merge branch 'main' into woosuk-jf 2025-05-03 10:42:43 -07:00
eeb5761cf1 Implement Jump-Forward (Fast-Forward) Decoding
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-01 18:08:52 -07:00
2 changed files with 39 additions and 5 deletions

View File

@ -721,15 +721,29 @@ class Scheduler(SchedulerInterface):
# the outer lists can be of length > 1. # the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
jump_tokens = []
if new_token_ids and request.use_structured_output: if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request assert request.structured_output_request is not None
# should not be None if use_structured_output, we have assert request.structured_output_request.grammar is not None
# check above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens(
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids) req_id, new_token_ids)
if not stopped:
jump_tokens = request.structured_output_request.grammar.jump_forward(
req_id)
for token_id in jump_tokens:
request.append_output_token_ids(token_id)
new_token_ids.append(token_id)
stopped = check_stop(request, self.max_model_len)
if stopped:
break
if jump_tokens:
print(f"jump_tokens: {jump_tokens}")
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if jump_tokens:
request.spec_token_ids.clear()
elif spec_token_ids is not None:
if request.use_structured_output: if request.use_structured_output:
metadata = request.structured_output_request metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None assert metadata is not None and metadata.grammar is not None

View File

@ -170,6 +170,26 @@ class XgrammarGrammar(StructuredOutputGrammar):
self.num_processed_tokens += 1 self.num_processed_tokens += 1
return True return True
def jump_forward(
    self,
    request_id: str,
) -> list[int]:
    """Consume every token that the grammar leaves no choice about.

    Repeatedly fills a token bitmask for the current grammar state and,
    while that bitmask admits exactly one token, accepts the token and
    records it. The loop ends when the grammar terminates, when more than
    one token is allowed, or when the grammar refuses the forced token.

    Args:
        request_id: Identifier forwarded to ``accept_tokens``.

    Returns:
        The token ids that were accepted during the jump, in order. Empty
        when the very first step already allows several tokens (or the
        grammar is already terminated).
    """
    # Scratch bitmask reused across iterations.
    # NOTE(review): presumably a single-row mask (first argument is 1) —
    # confirm against the xgrammar API.
    bitmask = xgr.allocate_token_bitmask(1, self.vocab_size)
    jump_forward_tokens: list[int] = []
    while not self.is_terminated():
        self.fill_bitmask(bitmask, 0)
        # NOTE(review): `_is_single_token_bitmask` lives under
        # `xgr.testing` and is underscore-prefixed, i.e. not a public API;
        # it may change between xgrammar releases.
        is_single, unique_token_id = xgr.testing._is_single_token_bitmask(
            bitmask,
            vocab_size=self.vocab_size,
            index=0,
        )
        if not is_single:
            break
        # Stop if the grammar rejects the token: the state would not
        # advance, so the same bitmask would be produced again and the
        # loop would never end while collecting unaccepted tokens.
        if not self.accept_tokens(request_id, [unique_token_id]):
            break
        jump_forward_tokens.append(unique_token_id)
    return jump_forward_tokens
def validate_tokens(self, tokens: list[int]) -> list[int]: def validate_tokens(self, tokens: list[int]) -> list[int]:
"""Checks if the list of tokens are accepted by the FSM in sequence. """Checks if the list of tokens are accepted by the FSM in sequence.
Will not advance the FSM. Will not advance the FSM.