diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 6c6a8a7bce3..9e86763a544 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -7,7 +7,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.structured_output.backend_guidance import GuidanceBackend
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
@@ -45,14 +45,36 @@ def grammar_init(self, request: Request) -> None:
         # NOTE: We only support a single backend. We do NOT support different
         # backends on a per-request basis in V1 (for now, anyway...).
         if self.backend is None:
-            backend_name = request.sampling_params.guided_decoding.backend_name
+            # NOTE: We already set up guided_decoding in sampling_params
+            # from vllm/v1/processor.py#L135, hence "type: ignore" directive.
+            backend_name = request.sampling_params.guided_decoding.backend_name  # type: ignore[union-attr]
+            tokenizer_group = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                parallel_config=self.vllm_config.parallel_config,
+                lora_config=self.vllm_config.lora_config)
+            tokenizer_group.ping()
+            tokenizer = tokenizer_group.get_lora_tokenizer(None)
+            vocab_size = self.vllm_config.model_config.get_vocab_size()
+
             if backend_name == "xgrammar":
                 from vllm.v1.structured_output.backend_xgrammar import (
                     XgrammarBackend)
 
-                self.backend = XgrammarBackend(self.vllm_config)
+                self.backend = XgrammarBackend(
+                    vllm_config=self.vllm_config,
+                    tokenizer=tokenizer,
+                    vocab_size=vocab_size,
+                )
             elif backend_name == "guidance":
-                self.backend = GuidanceBackend(self.vllm_config)
+                from vllm.v1.structured_output.backend_guidance import (
+                    GuidanceBackend)
+
+                self.backend = GuidanceBackend(
+                    vllm_config=self.vllm_config,
+                    tokenizer=tokenizer,
+                    vocab_size=vocab_size,
+                )
             else:
                 raise ValueError(
                     f"Unsupported structured output backend: {backend_name}")
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index 1e274ad0ae6..c3c8af03a16 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -1,15 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar,
@@ -29,21 +29,11 @@
 logger = init_logger(__name__)
 
 
+@dataclass
 class GuidanceBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
-        self.vllm_config = vllm_config
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
+    def __post_init__(self):
+        self.ll_tokenizer = llguidance_hf.from_tokenizer(self.tokenizer, None)
 
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
@@ -69,6 +59,14 @@ def allocate_token_bitmask(self, max_num_seqs: int):
         return llguidance_torch.allocate_token_bitmask(
             max_num_seqs, self.ll_tokenizer.vocab_size)
 
+    def encode_with_jump(
+        self,
+        output_token_ids: list[int],
+        jump_forward_string: str,
+    ) -> list[int]:
+        # TO BE IMPLEMENTED
+        pass
+
 
 @dataclass
 class GuidanceGrammar(StructuredOutputGrammar):
@@ -120,9 +118,18 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
     def is_terminated(self) -> bool:
         return self.terminated
 
-    def reset(self):
-        # This method may be not needed anymore? TODO
-        self.ll_matcher.reset()
+    def find_token_divergence(
+        self,
+        request_id: str,
+        prev_tokens: list[int],
+        combined_tokens: list[int],
+    ) -> int:
+        # TO BE IMPLEMENTED
+        pass
+
+    def jump_forward_string(self) -> str | None:
+        # TO BE IMPLEMENTED
+        return
 
 
 def serialize_guidance_grammar(request_type: StructuredOutputOptions,
diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py
index 6dc2a92411d..cbfb6ffe7ab 100644
--- a/vllm/v1/structured_output/backend_types.py
+++ b/vllm/v1/structured_output/backend_types.py
@@ -1,10 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import enum
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
 
 import torch
 
+from vllm.config import VllmConfig
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+    from vllm.transformers_utils.tokenizer import AnyTokenizer
+
 
 class StructuredOutputOptions(enum.Enum):
     JSON = enum.auto()
@@ -20,6 +30,15 @@ class StructuredOutputOptions(enum.Enum):
 class StructuredOutputGrammar(ABC):
     """Request-level backend for structured output requests."""
 
+    @abstractmethod
+    def jump_forward_string(self) -> str | None:
+        """
+        Gets the string the grammar can deterministically jump forward with.
+
+        Returns:
+            str | None: The jump-forward string, or None if there is none.
+        """
+
    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """
@@ -54,14 +73,35 @@ def is_terminated(self) -> bool:
         """
 
     @abstractmethod
-    def reset(self):
+    def find_token_divergence(
+        self,
+        request_id: str,
+        prev_tokens: list[int],
+        combined_tokens: list[int],
+    ) -> int:
         """
-        Resets the state of the structured output grammar.
+        Finds the index where two token sequences diverge.
+        Note that each grammar should handle its FSM rollback accordingly.
+
+        Args:
+            request_id: The unique identifier for the request.
+            prev_tokens: Original token sequence.
+            combined_tokens: New token sequence that
+                should start with prev_tokens.
+
+        Returns:
+            int: Index where the sequences diverge.
         """
 
 
+@dataclass
 class StructuredOutputBackend(ABC):
-    """Engine-level backend for structured output requests."""
+    """Engine-level backend for structured output requests.
+
+    Make sure that all subclasses are also dataclasses."""
+    vllm_config: VllmConfig
+    tokenizer: AnyTokenizer
+    vocab_size: int
 
     @abstractmethod
     def compile_grammar(self, request_type: StructuredOutputOptions,
@@ -79,7 +119,7 @@ def compile_grammar(self, request_type: StructuredOutputOptions,
         """
 
     @abstractmethod
-    def allocate_token_bitmask(self, max_num_seqs: int):
+    def allocate_token_bitmask(self, max_num_seqs: int) -> None:
         """
         Allocates a token bitmask for the specified maximum number of
         sequences.
@@ -87,3 +127,22 @@
             max_num_seqs (int): The maximum number of sequences for which
                 to allocate the bitmask.
         """
+
+    @abstractmethod
+    def encode_with_jump(
+        self,
+        output_token_ids: list[int],
+        jump_forward_string: str,
+    ) -> list[int]:
+        """
+        Handles retokenization after appending the jump-forward string to
+        the output generated so far, and returns the resulting new tokens.
+
+        Args:
+            output_token_ids: The token ids generated so far for the request.
+            jump_forward_string: The string to jump forward with.
+
+        Returns:
+            list[int]: The list of new tokens, including
+                the jump-forward string.
+        """
diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py
index 9bfb644c580..109ddb1deb8 100644
--- a/vllm/v1/structured_output/backend_xgrammar.py
+++ b/vllm/v1/structured_output/backend_xgrammar.py
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
+import itertools
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
 
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -22,42 +23,34 @@
 logger = init_logger(__name__)
 
 
+@dataclass
 class XgrammarBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
+    def __post_init__(self):
         self.disable_any_whitespace = (
             "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
+            in self.vllm_config.decoding_config.guided_decoding_backend)
+
+        if isinstance(self.tokenizer, MistralTokenizer):
             # NOTE: ideally, xgrammar should handle this accordingly.
             # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
             try:
                 encoded_vocab = [
                     token for token, _ in sorted(
-                        tokenizer.get_vocab().items(),
+                        self.tokenizer.get_vocab().items(),
                         key=lambda x: x[1],
                     )
                 ]
                 stop_token_ids = None
                 if hasattr(
-                        tokenizer,
+                        self.tokenizer,
                         "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
+                ) and self.tokenizer.eos_token_id is not None:
+                    stop_token_ids = [self.tokenizer.eos_token_id]
             except AttributeError as e:
                 raise ValueError(
                     f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
+                    f"{type(self.tokenizer)}. The tokenizer should have a "
                     "get_vocab method.") from e
             tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                 encoded_vocab=encoded_vocab,
@@ -69,7 +62,7 @@ def __init__(self, vllm_config: VllmConfig):
             )
         else:
             tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
+                self.tokenizer,
                 vocab_size=self.vocab_size,
             )
         self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
@@ -93,7 +86,10 @@ def compile_grammar(self, request_type: StructuredOutputOptions,
                 f"grammar is not of valid supported types. ({request_type!s})")
 
         return XgrammarGrammar(
-            matcher=xgr.GrammarMatcher(ctx),
+            # NOTE: conservatively fix this value for now, given that the
+            # length of the jump-forward string in theory has no upper
+            # bound.
+            matcher=xgr.GrammarMatcher(ctx, max_rollback_tokens=50),
             vocab_size=self.vocab_size,
             ctx=ctx,
         )
@@ -101,24 +97,19 @@ def allocate_token_bitmask(self, max_num_seqs: int):
         return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
 
+    def encode_with_jump(
+        self,
+        output_token_ids: list[int],
+        jump_forward_string: str,
+    ) -> list[int]:
+        ...
+
 
 @dataclass
 class XgrammarGrammar(StructuredOutputGrammar):
-    # NOTE: This would be a generic-enough class for
-    # supporting different backends, in the future.
-    # For now, just xgrammar.
-    #
-    # TODO: support max_rollback_tokens
-    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
-    # for jump-forward decoding
-
     vocab_size: int
     matcher: xgr.GrammarMatcher = field(hash=False)
     ctx: xgr.CompiledGrammar = field(hash=False)
-    num_processed_tokens: int = field(default_factory=lambda: 0,
-                                      repr=False,
-                                      hash=False,
-                                      init=False)
 
     def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
         """Accepts a list of tokens and advances the FSM.
 
@@ -132,7 +123,6 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
                     "Failed to advance FSM for request %s "
                     "for tokens %s. Please file an issue.", request_id, token)
                 return False
-            self.num_processed_tokens += 1
         return True
 
     def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
@@ -141,6 +131,24 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
     def is_terminated(self) -> bool:
         return self.matcher.is_terminated()
 
-    def reset(self):
-        self.num_processed_tokens = 0
-        self.matcher.reset()
+    def jump_forward_string(self) -> str | None:
+        jf_string = self.matcher.find_jump_forward_string()
+        return jf_string if jf_string else None
+
+    def find_token_divergence(
+        self,
+        request_id: str,
+        prev_tokens: list[int],
+        combined_tokens: list[int],
+    ) -> int:
+        min_len = min(len(prev_tokens), len(combined_tokens))
+        k = sum(1 for _ in itertools.takewhile(
+            lambda x: x[0] == x[1],
+            zip(prev_tokens[:min_len], combined_tokens[:min_len]),
+        ))
+
+        # We have to roll back the FSM to the divergence point.
+        if k < len(prev_tokens):
+            self.matcher.rollback(len(prev_tokens) - k)
+        assert self.accept_tokens(request_id, combined_tokens[k:])
+        return k
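Editor's note, not part of the patch: the three new hooks (jump_forward_string, encode_with_jump, find_token_divergence) are meant to be driven together by the caller during jump-forward decoding. The sketch below only illustrates that call sequence under the interfaces added above; try_jump_forward, backend, grammar, request_id, and output_token_ids are hypothetical placeholders for the caller's own state, and the real scheduler integration is not shown in this diff.

# Illustrative sketch only -- not code from this patch.
def try_jump_forward(backend, grammar, request_id, output_token_ids):
    # Ask the grammar whether it can force a deterministic continuation.
    jf_string = grammar.jump_forward_string()
    if jf_string is None:
        # Nothing can be jump-forwarded; keep the tokens as they are.
        return output_token_ids, len(output_token_ids)

    # Retokenize the generated output together with the forced string.
    combined_tokens = backend.encode_with_jump(output_token_ids, jf_string)

    # Retokenization may rewrite the tail of the existing tokens, so find the
    # first index where the sequences differ; the grammar rolls its FSM back
    # and re-accepts combined_tokens[k:] internally.
    k = grammar.find_token_divergence(request_id, output_token_ids,
                                      combined_tokens)

    # Tokens from index k onward replace whatever was generated previously.
    return combined_tokens, k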