fix: do not wrap invalid grammar objects during constrained generation (#11328)
This commit is contained in:
@@ -17,7 +17,11 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
|
||||
from .base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
|
||||
|
||||
class ReasonerGrammarObject(BaseGrammarObject):
|
||||
@@ -81,10 +85,9 @@ class ReasonerGrammarBackend(BaseGrammarBackend):
|
||||
self.grammar_backend = grammar_backend
|
||||
self.think_end_id = think_end_id
|
||||
|
||||
def _init_value_dispatch(
|
||||
self, key: Tuple[str, str]
|
||||
) -> Optional[ReasonerGrammarObject]:
|
||||
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
ret = self.grammar_backend._init_value_dispatch(key)
|
||||
if ret is None:
|
||||
return None
|
||||
# avoid wrapping invalid grammar, so that the scheduler can detect it
|
||||
if ret is None or ret is INVALID_GRAMMAR_OBJ:
|
||||
return ret
|
||||
return ReasonerGrammarObject(ret, self.think_end_id)
|
||||
|
||||
Reference in New Issue
Block a user