Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from xgrammar import (
|
||||
@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
|
||||
class XGrammarGrammar(BaseGrammarObject):
|
||||
|
||||
def __init__(
|
||||
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
|
||||
self,
|
||||
matcher: GrammarMatcher,
|
||||
vocab_size: int,
|
||||
ctx: CompiledGrammar,
|
||||
override_stop_tokens: Optional[Union[List[int], int]],
|
||||
) -> None:
|
||||
self.matcher = matcher
|
||||
self.vocab_size = vocab_size
|
||||
self.ctx = ctx
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
self.finished = False
|
||||
|
||||
def accept_token(self, token: int):
|
||||
@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
apply_token_bitmask_inplace(logits, vocab_mask)
|
||||
|
||||
def copy(self):
|
||||
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
||||
matcher = GrammarMatcher(
|
||||
self.ctx,
|
||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
)
|
||||
return XGrammarGrammar(
|
||||
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
|
||||
)
|
||||
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=vocab_size
|
||||
)
|
||||
override_stop_tokens = None
|
||||
|
||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||
self.vocab_size = vocab_size
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
|
||||
@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
||||
|
||||
def reset(self):
|
||||
if self.grammar_compiler:
|
||||
|
||||
Reference in New Issue
Block a user