Add reasoning parser mechanism + qwen3 parser + bugfixes
This commit is contained in:
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Reasoning parser module for vLLM 0.6.3 (BI-V100 / Qwen3.6-27B adaptation).
|
||||
|
||||
Usage: --reasoning-parser qwen3
|
||||
"""
|
||||
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
|
||||
__all__ = ["ReasoningParser", "ReasoningParserManager"]
|
||||
|
||||
# Lazy-register Qwen3 parser; imported on first get_reasoning_parser("qwen3").
|
||||
ReasoningParserManager.register_lazy(
|
||||
"qwen3",
|
||||
"vllm.reasoning.qwen3_reasoning_parser",
|
||||
"Qwen3ReasoningParser",
|
||||
)
|
||||
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Abstract reasoning parser base classes for vLLM 0.6.3.
|
||||
Adapted from vllm-original/vllm/reasoning/abs_reasoning_parsers.py:
|
||||
- Removed vllm.entrypoints.mcp, vllm.utils.collection_utils, import_utils
|
||||
- DeltaMessage from vllm 0.6.3 protocol path
|
||||
- TokenizerLike -> AnyTokenizer
|
||||
- ReasoningParserManager: simplified eager + lazy registration
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import cached_property
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
else:
|
||||
DeltaMessage = Any
|
||||
AnyTokenizer = Any
|
||||
|
||||
|
||||
class ReasoningParser:
|
||||
"""Abstract base for all reasoning parsers."""
|
||||
|
||||
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict:
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
@abstractmethod
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
"""Return True once the reasoning block has closed in input_ids."""
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||
) -> bool:
|
||||
return self.is_reasoning_end(input_ids)
|
||||
|
||||
@abstractmethod
|
||||
def extract_content_ids(self, input_ids: list) -> list:
|
||||
"""Return token ids that belong to the content (post-reasoning) part."""
|
||||
|
||||
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: Any
|
||||
) -> "tuple[Optional[str], Optional[str]]":
|
||||
"""
|
||||
Split a complete model output into (reasoning_text, content_text).
|
||||
Either part may be None.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Optional["DeltaMessage"]:
|
||||
"""
|
||||
Extract reasoning from a streaming delta.
|
||||
Returns a DeltaMessage with reasoning_content and/or content set,
|
||||
or None if this delta should be suppressed (control token).
|
||||
"""
|
||||
|
||||
|
||||
class BaseThinkingReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Base for parsers that use <start_token>...</end_token> delimiters.
|
||||
Subclasses define start_token / end_token properties.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def start_token(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def end_token(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError("Tokenizer must be passed to ReasoningParser.")
|
||||
if not self.start_token or not self.end_token:
|
||||
raise ValueError("start_token and end_token must be defined.")
|
||||
|
||||
self.start_token_id: Optional[int] = self.vocab.get(self.start_token)
|
||||
self.end_token_id: Optional[int] = self.vocab.get(self.end_token)
|
||||
if self.start_token_id is None or self.end_token_id is None:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}: could not find think tokens "
|
||||
f"'{self.start_token}'/'{self.end_token}' in tokenizer vocab."
|
||||
)
|
||||
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
for token_id in reversed(input_ids):
|
||||
if token_id == self.start_token_id:
|
||||
return False
|
||||
if token_id == self.end_token_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||
) -> bool:
|
||||
return self.end_token_id in delta_ids
|
||||
|
||||
def extract_content_ids(self, input_ids: list) -> list:
|
||||
if self.end_token_id not in input_ids[:-1]:
|
||||
return []
|
||||
return input_ids[input_ids.index(self.end_token_id) + 1:]
|
||||
|
||||
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||
count = 0
|
||||
depth = 0
|
||||
for tid in token_ids:
|
||||
if tid == self.start_token_id:
|
||||
depth += 1
|
||||
elif tid == self.end_token_id:
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
elif depth > 0:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: Any
|
||||
) -> "tuple[Optional[str], Optional[str]]":
|
||||
# Strip <think> if the model generated it (old-style template).
|
||||
parts = model_output.partition(self.start_token)
|
||||
model_output = parts[2] if parts[1] else parts[0]
|
||||
|
||||
if self.end_token not in model_output:
|
||||
return model_output, None
|
||||
reasoning, _, content = model_output.partition(self.end_token)
|
||||
return reasoning, content or None
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Optional["DeltaMessage"]:
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage
|
||||
|
||||
# Suppress lone control tokens.
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] in (
|
||||
self.start_token_id, self.end_token_id
|
||||
):
|
||||
return None
|
||||
|
||||
start_in_prev = self.start_token_id in previous_token_ids
|
||||
start_in_delta = self.start_token_id in delta_token_ids
|
||||
end_in_prev = self.end_token_id in previous_token_ids
|
||||
end_in_delta = self.end_token_id in delta_token_ids
|
||||
|
||||
if start_in_prev:
|
||||
if end_in_delta:
|
||||
end_idx = delta_text.find(self.end_token)
|
||||
reasoning = delta_text[:end_idx] if end_idx >= 0 else ""
|
||||
content = delta_text[end_idx + len(self.end_token):] if end_idx >= 0 else None
|
||||
return _DeltaMessage(
|
||||
reasoning_content=reasoning or None,
|
||||
content=content or None,
|
||||
)
|
||||
elif end_in_prev:
|
||||
return _DeltaMessage(content=delta_text)
|
||||
else:
|
||||
return _DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
elif start_in_delta:
|
||||
if end_in_delta:
|
||||
start_idx = delta_text.find(self.start_token)
|
||||
end_idx = delta_text.find(self.end_token)
|
||||
reasoning = delta_text[start_idx + len(self.start_token):end_idx]
|
||||
content = delta_text[end_idx + len(self.end_token):]
|
||||
return _DeltaMessage(
|
||||
reasoning_content=reasoning or None,
|
||||
content=content or None,
|
||||
)
|
||||
else:
|
||||
return _DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
else:
|
||||
return _DeltaMessage(content=delta_text)
|
||||
|
||||
|
||||
class ReasoningParserManager:
|
||||
"""
|
||||
Registry for ReasoningParser implementations.
|
||||
Supports eager and lazy registration.
|
||||
"""
|
||||
|
||||
_parsers: dict = {} # name -> class (eager)
|
||||
_lazy: dict = {} # name -> (module_path, class_name)
|
||||
|
||||
@classmethod
|
||||
def register_module(cls, name: str, parser_cls: type) -> None:
|
||||
"""Eagerly register a ReasoningParser class."""
|
||||
if not issubclass(parser_cls, ReasoningParser):
|
||||
raise TypeError(f"{parser_cls} is not a ReasoningParser subclass.")
|
||||
cls._parsers[name] = parser_cls
|
||||
|
||||
@classmethod
|
||||
def register_lazy(cls, name: str, module_path: str, class_name: str) -> None:
|
||||
"""Register a parser for deferred import."""
|
||||
cls._lazy[name] = (module_path, class_name)
|
||||
|
||||
@classmethod
|
||||
def get_reasoning_parser(cls, name: str) -> type:
|
||||
if name in cls._parsers:
|
||||
return cls._parsers[name]
|
||||
if name in cls._lazy:
|
||||
module_path, class_name = cls._lazy[name]
|
||||
mod = importlib.import_module(module_path)
|
||||
parser_cls = getattr(mod, class_name)
|
||||
cls._parsers[name] = parser_cls
|
||||
return parser_cls
|
||||
registered = sorted(set(cls._parsers) | set(cls._lazy))
|
||||
raise KeyError(
|
||||
f"Reasoning parser '{name}' not found. "
|
||||
f"Available: {registered}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
return sorted(set(cls._parsers) | set(cls._lazy))
|
||||
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Reasoning parser for Qwen3 / Qwen3.5 / Qwen3.6 model family.
|
||||
Adapted from vllm-original/vllm/reasoning/qwen3_reasoning_parser.py.
|
||||
|
||||
The model uses <think>...</think> to wrap chain-of-thought output.
|
||||
For Qwen3.5+ the chat template injects <think> into the prompt, so only
|
||||
</think> appears in the generated tokens; older templates generate <think>
|
||||
themselves. Both styles are handled.
|
||||
"""
|
||||
|
||||
from typing import Optional, Sequence, Any
|
||||
|
||||
from vllm.reasoning.abs_reasoning_parsers import (
|
||||
BaseThinkingReasoningParser,
|
||||
ReasoningParserManager,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3ReasoningParser(BaseThinkingReasoningParser):
|
||||
|
||||
def __init__(self, tokenizer: Any, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||
self.thinking_enabled = chat_kwargs.get("enable_thinking", True)
|
||||
|
||||
@property
|
||||
def start_token(self) -> str:
|
||||
return "<think>"
|
||||
|
||||
@property
|
||||
def end_token(self) -> str:
|
||||
return "</think>"
|
||||
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: Any
|
||||
) -> "tuple[Optional[str], Optional[str]]":
|
||||
# Strip <think> if the model generated it (old template / edge case).
|
||||
parts = model_output.partition(self.start_token)
|
||||
model_output = parts[2] if parts[1] else parts[0]
|
||||
|
||||
if self.end_token not in model_output:
|
||||
if not self.thinking_enabled:
|
||||
return None, model_output
|
||||
# Thinking enabled but output truncated before </think>.
|
||||
return model_output, None
|
||||
|
||||
reasoning, _, content = model_output.partition(self.end_token)
|
||||
return reasoning, content or None
|
||||
|
||||
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||
token_ids = list(token_ids)
|
||||
if self.start_token_id in token_ids:
|
||||
# Old-style template: model generates <think> itself.
|
||||
# Use depth-counting from the base class.
|
||||
return super().count_reasoning_tokens(token_ids)
|
||||
elif self.end_token_id in token_ids:
|
||||
# New-style template (Qwen3.5+): <think> is injected into the
|
||||
# prompt, so output starts already inside the thinking block.
|
||||
# Every token before </think> is a reasoning token.
|
||||
return token_ids.index(self.end_token_id)
|
||||
else:
|
||||
# No </think> in output: either truncated (all reasoning)
|
||||
# or thinking disabled (none).
|
||||
return len(token_ids) if self.thinking_enabled else 0
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
):
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||
|
||||
if not self.thinking_enabled:
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
# Strip <think> from delta if the model generates it itself.
|
||||
if self.start_token_id in delta_token_ids:
|
||||
start_idx = delta_text.find(self.start_token)
|
||||
if start_idx >= 0:
|
||||
delta_text = delta_text[start_idx + len(self.start_token):]
|
||||
|
||||
if self.end_token_id in delta_token_ids:
|
||||
end_idx = delta_text.find(self.end_token)
|
||||
if end_idx >= 0:
|
||||
reasoning = delta_text[:end_idx]
|
||||
content = delta_text[end_idx + len(self.end_token):]
|
||||
if not reasoning and not content:
|
||||
return None
|
||||
return DeltaMessage(
|
||||
reasoning_content=reasoning or None,
|
||||
content=content or None,
|
||||
)
|
||||
return None
|
||||
|
||||
if not delta_text:
|
||||
return None
|
||||
elif self.end_token_id in previous_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
|
||||
# Register immediately when this module is imported.
|
||||
ReasoningParserManager.register_module("qwen3", Qwen3ReasoningParser)
|
||||
Reference in New Issue
Block a user