Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -15,7 +15,7 @@
import json
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import torch
from xgrammar import (
@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject):
def __init__(
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
self,
matcher: GrammarMatcher,
vocab_size: int,
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
) -> None:
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False
def accept_token(self, token: int):
@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self):
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
matcher = GrammarMatcher(
self.ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
)
class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
raise ValueError(f"Invalid key_type: {key_type}")
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx)
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def reset(self):
if self.grammar_compiler: