[V1][Draft] Jump-forward decoding #15490

Draft · wants to merge 1 commit into base: main
30 changes: 26 additions & 4 deletions vllm/v1/structured_output/__init__.py
@@ -7,7 +7,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)

@@ -45,14 +45,36 @@
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
# NOTE: We already setup guided_decoding in sampling_params
# from vllm/v1/processor.py#L135, hence "type: ignore" directive.
backend_name = request.sampling_params.guided_decoding.backend_name # type: ignore[union-attr]
tokenizer_group = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
parallel_config=self.vllm_config.parallel_config,
lora_config=self.vllm_config.lora_config)
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
vocab_size = self.vllm_config.model_config.get_vocab_size()

if backend_name == "xgrammar":
from vllm.v1.structured_output.backend_xgrammar import (
XgrammarBackend)

self.backend = XgrammarBackend(self.vllm_config)
self.backend = XgrammarBackend(
vllm_config=self.vllm_config,
tokenizer=tokenizer,
vocab_size=vocab_size,
)
elif backend_name == "guidance":
self.backend = GuidanceBackend(self.vllm_config)
from vllm.v1.structured_output.backend_guidance import (
GuidanceBackend)

self.backend = GuidanceBackend(
vllm_config=self.vllm_config,
tokenizer=tokenizer,
vocab_size=vocab_size,
)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")
@@ -100,10 +122,10 @@
assert request is not None and request.grammar is not None
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]

[GitHub Actions / pre-commit] Check failure on line 125 in vllm/v1/structured_output/__init__.py: Item "None" of "Optional[Any]" has no attribute "shape" [union-attr]
[GitHub Actions / pre-commit] Check failure on line 126 in vllm/v1/structured_output/__init__.py: Value of type "Optional[Any]" is not indexable [index]
# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()

[GitHub Actions / pre-commit] Check failure on line 131 in vllm/v1/structured_output/__init__.py: Item "None" of "Optional[Any]" has no attribute "numpy" [union-attr]
43 changes: 25 additions & 18 deletions vllm/v1/structured_output/backend_guidance.py
@@ -1,15 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
@@ -29,21 +29,11 @@
logger = init_logger(__name__)


@dataclass
class GuidanceBackend(StructuredOutputBackend):

def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.vllm_config = vllm_config
self.vocab_size = vllm_config.model_config.get_vocab_size()

tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
def __post_init__(self):
self.ll_tokenizer = llguidance_hf.from_tokenizer(self.tokenizer, None)

def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
@@ -69,6 +59,14 @@
return llguidance_torch.allocate_token_bitmask(
max_num_seqs, self.ll_tokenizer.vocab_size)

def encode_with_jump(
self,
output_token_ids: list[int],
jump_forward_string: str,
) -> list[int]:
# TO BE IMPLEMENTED
pass

[GitHub Actions / pre-commit] Check failure on line 62 in vllm/v1/structured_output/backend_guidance.py: Missing return statement [empty-body]


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
@@ -120,9 +118,18 @@
def is_terminated(self) -> bool:
return self.terminated

def reset(self):
# This method may be not needed anymore? TODO
self.ll_matcher.reset()
def find_token_divergence(
self,
request_id: str,
prev_tokens: list[int],
combined_tokens: list[int],
) -> int:
# TO BE IMPLEMENTED
pass

[GitHub Actions / pre-commit] Check failure on line 121 in vllm/v1/structured_output/backend_guidance.py: Missing return statement [empty-body]

def jump_forward_string(self) -> str | None:
# TO BE IMPLEMENTED
return

[GitHub Actions / pre-commit] Check failure on line 132 in vllm/v1/structured_output/backend_guidance.py: Return value expected [return-value]


def serialize_guidance_grammar(request_type: StructuredOutputOptions,
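`encode_with_jump` is left as a stub in both backends. For orientation, here is a minimal sketch of the usual retokenization approach (decode the generated tokens, append the forced string, re-encode), assuming an HF-style tokenizer with `decode`/`encode`; the function name and body are illustrative, not the PR's implementation:

```python
def encode_with_jump_sketch(tokenizer, output_token_ids: list[int],
                            jump_forward_string: str) -> list[int]:
    # Decode what has been generated so far, append the forced string, and
    # retokenize the combined text. Retokenizing is needed because the jump
    # string can merge with the trailing generated token(s) into different
    # token ids than a naive concatenation would produce.
    prev_text = tokenizer.decode(output_token_ids)
    return tokenizer.encode(prev_text + jump_forward_string,
                            add_special_tokens=False)
```

A real implementation would likely bound how far back it retokenizes, since only tokens near the boundary with the jump-forward string can change.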
67 changes: 63 additions & 4 deletions vllm/v1/structured_output/backend_types.py
@@ -1,10 +1,20 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm.config import VllmConfig

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer


class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
@@ -20,6 +30,15 @@ class StructuredOutputOptions(enum.Enum):
class StructuredOutputGrammar(ABC):
"""Request-level backend for structured output requests."""

@abstractmethod
def jump_forward_string(self) -> str | None:
"""
Get jump forward string and returns its tokens and string

Returns:
str: Optional jump forward string
"""

@abstractmethod
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""
@@ -54,14 +73,35 @@ def is_terminated(self) -> bool:
"""

@abstractmethod
def reset(self):
def find_token_divergence(
self,
request_id: str,
prev_tokens: list[int],
combined_tokens: list[int],
) -> int:
"""
Resets the state of the structured output grammar.
Finds the index where two token sequences diverge.
Note that each grammar should handle its FSM rollback accordingly.

Args:
request_id: The unique identifier for the request.
prev_tokens: Original token sequence
combined_tokens: New token sequence that
should start with prev_tokens

Returns:
int: Index where the sequences diverge
"""


@dataclass
class StructuredOutputBackend(ABC):
"""Engine-level backend for structured output requests."""
"""Engine-level backend for structured output requests.

Make sure that all subclasses are also dataclasses."""
vllm_config: VllmConfig
tokenizer: AnyTokenizer
vocab_size: int

@abstractmethod
def compile_grammar(self, request_type: StructuredOutputOptions,
@@ -79,11 +119,30 @@ def compile_grammar(self, request_type: StructuredOutputOptions,
"""

@abstractmethod
def allocate_token_bitmask(self, max_num_seqs: int):
def allocate_token_bitmask(self, max_num_seqs: int) -> None:
"""
Allocates a token bitmask for the specified maximum number of sequences.

Args:
max_num_seqs (int): The maximum number of sequences for which
to allocate the bitmask.
"""

@abstractmethod
def encode_with_jump(
self,
output_token_ids: list[int],
jump_forward_string: str,
) -> list[int]:
"""
Handle retokenization with the jump forward string and
returns the new tokens and the number of previous tokens to replace.

Args:
request_id: The unique identifier for the request.
jump_forward_string: The string to jump forward with

Returns:
list[int]: Returns list of new tokens
including the jump forward string.
"""
84 changes: 46 additions & 38 deletions vllm/v1/structured_output/backend_xgrammar.py
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import itertools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -22,42 +23,34 @@
logger = init_logger(__name__)


@dataclass
class XgrammarBackend(StructuredOutputBackend):

def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
def __post_init__(self):
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()

tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
in self.vllm_config.decoding_config.guided_decoding_backend)

if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
self.tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
) and self.tokenizer.eos_token_id is not None:
stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
f"{type(self.tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab,
Expand All @@ -69,7 +62,7 @@
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
self.tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
@@ -93,32 +86,30 @@
f"grammar is not of valid supported types. ({request_type!s})")

return XgrammarGrammar(
matcher=xgr.GrammarMatcher(ctx),
# NOTE: conservatively fixed this value for now
# given that the length of jf-forward string in theory
# won't have a limit
matcher=xgr.GrammarMatcher(ctx, max_rollback_tokens=50),
vocab_size=self.vocab_size,
ctx=ctx,
)

def allocate_token_bitmask(self, max_num_seqs: int):
return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

def encode_with_jump(
self,
output_token_ids: list[int],
jump_forward_string: str,
) -> list[int]:
...

[GitHub Actions / pre-commit] Check failure on line 100 in vllm/v1/structured_output/backend_xgrammar.py: Missing return statement [empty-body]


@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
# NOTE: This would be a generic-enough class for
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding

vocab_size: int
matcher: xgr.GrammarMatcher = field(hash=False)
ctx: xgr.CompiledGrammar = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)

def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Expand All @@ -132,7 +123,6 @@
"Failed to advance FSM for request %s "
"for tokens %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
return True

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
@@ -141,6 +131,24 @@
def is_terminated(self) -> bool:
return self.matcher.is_terminated()

def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()
def jump_forward_string(self) -> str | None:
jf_string = self.matcher.find_jump_forward_string()
return jf_string if jf_string else None

def find_token_divergence(
self,
request_id: str,
prev_tokens: list[int],
combined_tokens: list[int],
) -> int:
min_len = min(len(prev_tokens), len(combined_tokens))
k = sum(1 for _ in itertools.takewhile(
lambda x: x[0] == x[1],
zip(prev_tokens[:(min_len)], combined_tokens[:min_len]),
))

# We have to rollback the tokens to the divergence point
if k < len(prev_tokens):
self.matcher.rollback(len(prev_tokens) - k)
assert self.accept_tokens(request_id, combined_tokens[k:])
return k
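
As a quick sanity check of the divergence logic above, a hypothetical toy example (not from the PR):

```python
import itertools

# Toy example: suppose the grammar has already accepted these tokens...
prev_tokens = [5, 7, 9]
# ...and retokenizing with the jump-forward string produced this sequence.
combined_tokens = [5, 7, 8, 11]

min_len = min(len(prev_tokens), len(combined_tokens))
k = sum(1 for _ in itertools.takewhile(
    lambda x: x[0] == x[1],
    zip(prev_tokens[:min_len], combined_tokens[:min_len]),
))
assert k == 2
# find_token_divergence would roll the matcher back by
# len(prev_tokens) - k == 1 token, then accept combined_tokens[k:] == [8, 11].
```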