fix: do not wrap invalid grammar objects during constrained generation (#11328)
This commit is contained in:
@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
|
from .base_grammar_backend import (
|
||||||
|
INVALID_GRAMMAR_OBJ,
|
||||||
|
BaseGrammarBackend,
|
||||||
|
BaseGrammarObject,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReasonerGrammarObject(BaseGrammarObject):
|
class ReasonerGrammarObject(BaseGrammarObject):
|
||||||
@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
|
|||||||
self.grammar_backend = grammar_backend
|
self.grammar_backend = grammar_backend
|
||||||
self.think_end_id = think_end_id
|
self.think_end_id = think_end_id
|
||||||
|
|
||||||
def _init_value_dispatch(
|
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||||
self, key: Tuple[str, str]
|
|
||||||
) -> Optional[ReasonerGrammarObject]:
|
|
||||||
ret = self.grammar_backend._init_value_dispatch(key)
|
ret = self.grammar_backend._init_value_dispatch(key)
|
||||||
if ret is None:
|
# avoid wrapping invalid grammar, so that the scheduler can detect it
|
||||||
return None
|
if ret is None or ret is INVALID_GRAMMAR_OBJ:
|
||||||
|
return ret
|
||||||
return ReasonerGrammarObject(ret, self.think_end_id)
|
return ReasonerGrammarObject(ret, self.think_end_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user