Compare commits
10 Commits
2d1ef50992...72aa7e690a

| SHA1 |
|---|
| 72aa7e690a |
| b5806731e0 |
| 47a4d9e72a |
| 3b8a567e9e |
| 50e3a05fb0 |
| 629f878c28 |
| 365da18436 |
| 4ab36b51d5 |
| d972854fb7 |
| c2de1c83b0 |
132
README.md
@@ -1,118 +1,46 @@
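# Iluvatar CoreX TianGai-100 (天垓100) Text Generation Engine (Optimized on vLLM)

This project is a high-performance text-generation inference engine tuned for the **Iluvatar CoreX TianGai-100** accelerator card. It adapts and extends the open-source **vLLM** framework at the architecture level and supports recent large models such as the **Qwen3 series**. Optimizations such as **Prefix Caching** and PagedAttention improve throughput and response latency, and a standard **OpenAI-compatible API** makes it straightforward to integrate with existing applications.

## Supported Models

- **Qwen3**
- **Llama3**
- **DeepSeek-R1-Distill**
- Other HuggingFace models compatible with vLLM (the list is still growing)

> Model downloads: [https://modelscope.cn/models/Qwen](https://modelscope.cn/models/Qwen)

---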
## Quick Start

### 1. Download the Model

Download the model you need from ModelScope (Qwen2.5-7B-Instruct is used as the example):

```bash
modelscope download --model qwen/Qwen2.5-7B-Instruct --local_dir /mnt/models/Qwen2.5-7B-Instruct
```

> ⚠️ Make sure this model path is mounted correctly when you start the Docker container later.

---
### 2. Pull and Build the Docker Image

We provide a Docker image with the TianGai-100 driver and the optimized vLLM build preinstalled:
# Iluvatar CoreX TianGai-100 (天垓100) Text Generation Engine (vLLM optimized and adapted for Qwen3.6-35B-A3B)

```
# Local build
docker build -t enginex-iluvatar-vllm:bi100 -f Dockerfile .
docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile .
```

---
### 3. Start the Service Container
Start the container image

```bash
docker run -it --rm -p 8000:80 \
  --name vllm-iluvatar \
  -v /mnt/models/Qwen2.5-7B-Instruct:/model:ro \
  --privileged \
  -e TENSOR_PARALLEL_SIZE=1 \
  -e PREFIX_CACHING=true \
  -e MAX_MODEL_LEN=10000 \
  enginex-iluvatar-vllm:bi100
Download the Qwen3.6-35B-A3B model, then change the `architectures` field in the model's config.json to:
```json
"architectures": [
  "Qwen3_5MoeForCausalLM"
]
```
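The `architectures` change described above can also be scripted. The following is a minimal sketch, assuming the model has been downloaded to a local directory that contains `config.json`; the helper name `set_architecture` is hypothetical and not part of this repository.

```python
import json
from pathlib import Path


def set_architecture(model_dir: str,
                     arch: str = "Qwen3_5MoeForCausalLM") -> list:
    """Rewrite the `architectures` field of <model_dir>/config.json.

    Hypothetical helper for illustration. Returns the previous value so the
    change can be reverted by hand if needed.
    """
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    previous = config.get("architectures", [])
    # Only this one field is replaced; every other key is written back as-is.
    config["architectures"] = [arch]
    config_path.write_text(
        json.dumps(config, indent=2, ensure_ascii=False) + "\n",
        encoding="utf-8")
    return previous
```

Run it on the host before starting the container, for example `set_architecture("/mnt/disk1/models/Qwen3.6-35B-A3B")`, since the `docker run` command mounts the model directory read-only.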

> ✅ Parameter notes:
> - `PREFIX_CACHING=true`: enables Prefix Caching, which speeds up inference when multiple requests share a prompt prefix
> - `MAX_MODEL_LEN=10000`: allows long-context inference
> - `--privileged`: makes sure the TianGai-100 devices are visible inside the container

---
## 4. Test the Service (OpenAI-Compatible API)

Once the service is up, you can test it with the standard OpenAI SDK or `curl`.

### Example: Text Generation Request

```bash
curl http://localhost:8000/v1/chat/completions \
docker run -dit --network=host --ipc=host \
  -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \
  -v /mnt/disk1/models/Qwen3.6-35B-A3B:/model:ro --entrypoint=python3 \
  -e CUDA_VISIBLE_DEVICES=4,5,6,7 -e VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 \
  enginex-iluvatar-vllm:bi100-qwen3.6 \
  -m vllm.entrypoints.openai.api_server \
  --model /model --port 1111 --served-model-name llm \
  --max-model-len 100000 --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
  --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
  --max-num-batched-tokens 4096 --enable-chunked-prefill \
  --max-seq-len-to-capture 32768 --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder --reasoning-parser qwen3
```
Request
```bash
curl http://localhost:1111/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen3-8b",
    "model": "llm",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "请用中文介绍一下上海的特点。"}
      {"role": "user", "content": "Can you tell me the story of Snow White?"}
    ],
    "temperature": 0.7,
    "max_tokens": 512
    "max_tokens": 200,
    "temperature": 0.7
  }'
```
### Using the OpenAI Python SDK (requires `openai>=1.0`)

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="qwen3-8b",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "请简要介绍杭州的特色文化。"}
    ],
    max_tokens=512,
    temperature=0.7
)

print(response.choices[0].message.content)
```

---
## Benchmark Comparison (A100 vs TianGai-100)

### Test Dataset

[chat_dataset_v0.json](chat_dataset_v0.json)

### Results

With the same model and the same inputs, we measured the average output speed in characters per second:

| Model | TianGai-100 output speed (chars/s) | Nvidia A100 output speed (chars/s) |
|---|---|---|
| Qwen2.5-7B-Instruct | 36.8 | 112.4 |
| Qwen2.5-1.5B-Instruct-AWQ | 72.4 | 100.8 |
| Qwen/Qwen1.5-32B-Chat | 12.4 | 55.7 |

```
@@ -1,42 +0,0 @@
# Iluvatar CoreX TianGai-100 (天垓100) Text Generation Engine (vLLM optimized and adapted for Qwen3.6-27B)

```
# Local build
docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile .
```


Start the container image

Download the Qwen3.6-27B model, then change the `architectures` field in the model's config.json to:
```json
"architectures": [
  "Qwen3_5ForCausalLM"
]
```
```bash
docker run -dit --network=host --ipc=host \
  -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \
  --name vllm-iluvatar \
  -v /mnt/models/Qwen3.6-27B:/model:ro --entrypoint=python3 \
  enginex-iluvatar-vllm:bi100 \
  -m vllm.entrypoints.openai.api_server \
  --model /model --port 1111 --served-model-name llm \
  --max-model-len 10000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95
```
Request
```bash
curl http://localhost:1111/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llm",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Can you tell me the story of Snow White?"}
    ],
    "max_tokens": 200,
    "temperature": 0.7
  }'
```
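The same request can be sent from Python using only the standard library. This is a sketch, assuming the server started by the `docker run` command is listening on `localhost:1111` with `--served-model-name llm`; `build_chat_request` and `ask` are illustrative names, not part of this repository.

```python
import json
import urllib.request


def build_chat_request(prompt, base_url="http://localhost:1111",
                       model="llm", max_tokens=200, temperature=0.7):
    """Build (but do not send) a POST request for /v1/chat/completions."""
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(prompt, **kwargs):
    """Send the request and return the assistant message dict.

    Requires a running server; the fields present in the message depend on
    how the server was configured.
    """
    with urllib.request.urlopen(build_chat_request(prompt, **kwargs)) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]
```

Calling `ask("Can you tell me the story of Snow White?")` mirrors the `curl` command. Returning the whole message rather than only `content` is deliberate: with a reasoning parser enabled, part of the output may arrive in a separate field, and `content` can be empty when `max_tokens` is small.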
33
computility-run.yaml
Normal file
@@ -0,0 +1,33 @@
gpu_num: 4
command:
  - python3
  - -m
  - vllm.entrypoints.openai.api_server
  - --model
  - /model
  - --served-model-name
  - llm
  - --max-model-len
  - '100000'
  - --gpu-memory-utilization
  - '0.95'
  - --trust-remote-code
  - -tp
  - '4'
  - --max-num-seqs
  - '1'
  - --disable-log-requests
  - --disable-frontend-multiprocessing
  - --max-num-batched-tokens
  - '4096'
  - --enable-chunked-prefill
  - --max-seq-len-to-capture
  - '32768'
  - --enable-auto-tool-choice
  - --tool-call-parser
  - qwen3_coder
  - --reasoning-parser
  - qwen3
env:
  - name: VLLM_ENGINE_ITERATION_TIMEOUT_S
    value: 3600
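To compare this file with the `docker run` command in the README, the `command` list can be joined back into a single shell line. The sketch below copies the list by hand rather than parsing the YAML, and its check that `gpu_num` equals the `-tp` value rests on the assumption that the platform expects the two to match; `flag_value` and `tp_matches_gpu_num` are hypothetical helpers.

```python
import shlex

# argv copied from the `command:` list in computility-run.yaml
COMMAND = [
    "python3", "-m", "vllm.entrypoints.openai.api_server",
    "--model", "/model",
    "--served-model-name", "llm",
    "--max-model-len", "100000",
    "--gpu-memory-utilization", "0.95",
    "--trust-remote-code",
    "-tp", "4",
    "--max-num-seqs", "1",
    "--disable-log-requests",
    "--disable-frontend-multiprocessing",
    "--max-num-batched-tokens", "4096",
    "--enable-chunked-prefill",
    "--max-seq-len-to-capture", "32768",
    "--enable-auto-tool-choice",
    "--tool-call-parser", "qwen3_coder",
    "--reasoning-parser", "qwen3",
]
GPU_NUM = 4  # `gpu_num` from the same file


def flag_value(argv, flag):
    """Return the token that follows `flag` in argv, or None if absent."""
    if flag in argv:
        index = argv.index(flag)
        if index + 1 < len(argv):
            return argv[index + 1]
    return None


def tp_matches_gpu_num(argv, gpu_num):
    """Assumed consistency rule: tensor-parallel size equals gpu_num."""
    return flag_value(argv, "-tp") == str(gpu_num)


if __name__ == "__main__":
    # Prints the argv as one shell-quoted command line.
    print(shlex.join(COMMAND))
```

One difference this makes visible: the YAML passes no `--port`, whereas the README command uses `--port 1111`, so the port here is left to whatever default the argument parser or the platform supplies.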
595
qwen3_6_scripts/api_server.py
Normal file
@@ -0,0 +1,595 @@
import asyncio
import importlib
import inspect
import multiprocessing
import os
import regex as re
import signal
import socket
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Set

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.reasoning import ReasoningParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

_running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine directly
        - multiprocess using AsyncLLMEngine RPC

    Returns the Client or None if the creation failed.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            #   cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for IPC Path.",
                    ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path))
        engine_process.start()
        logger.info("Started engine process with PID %d", engine_process.pid)

        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

        try:
            while True:
                try:
                    await mp_engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        raise RuntimeError(
                            "Engine process failed to start") from None

            yield mp_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mp_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if still running after the 4-second join timeout
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    prometheus_multiproc_dir_path)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def chat(request: Request) -> OpenAIServingChat:
    return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
    return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def embedding(request: Request) -> OpenAIServingEmbedding:
    return request.app.state.openai_serving_embedding


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    generator = await tokenization(raw_request).create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    generator = await tokenization(raw_request).create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, DetokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    models = await completion(raw_request).show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):

    generator = await chat(raw_request).create_chat_completion(
        request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await completion(raw_request).create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await embedding(raw_request).create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        logger.info("Starting profiler...")
        await engine_client(raw_request).start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        logger.info("Stopping profiler...")
        await engine_client(raw_request).stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        response = await chat(raw_request).load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        response = await completion(raw_request).load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        response = await chat(raw_request).unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        response = await completion(raw_request).unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        reasoning_parser=getattr(args, 'reasoning_parser', None))
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
async def run_server(args, **uvicorn_kwargs) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
            and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(choose from {{ {','.join(valid_tool_parsers)} }})")

    reasoning_parser = getattr(args, 'reasoning_parser', None)
    if reasoning_parser:
        valid_reasoning = ReasoningParserManager.list_registered()
        if reasoning_parser not in valid_reasoning:
            raise KeyError(
                f"invalid reasoning parser: {reasoning_parser} "
                f"(choose from {{ {','.join(valid_reasoning)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", args.port))

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            fd=sock.fileno(),
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
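Review note: the bearer-token middleware registered in `build_app` reduces to a small decision rule. The sketch below restates that rule as a pure function so it can be reasoned about and tested in isolation; `is_request_allowed` is a hypothetical name used only for illustration and does not exist in vLLM or in this patch.

```python
from typing import Optional


def is_request_allowed(method: str, path: str, authorization: Optional[str],
                       token: str, root_path: Optional[str] = None) -> bool:
    """Mirror the checks made by the `authentication` middleware.

    Illustrative only: returns True where the middleware would forward the
    request and False where it would answer 401.
    """
    prefix = "" if root_path is None else root_path
    # CORS preflight requests are always forwarded.
    if method == "OPTIONS":
        return True
    # Only paths under <root_path>/v1 are protected.
    if not path.startswith(f"{prefix}/v1"):
        return True
    # Protected paths need an exact "Bearer <token>" header.
    return authorization == "Bearer " + token
```

Read this way, routes outside `/v1`, such as `/health`, `/metrics`, `/tokenize` and `/version`, stay reachable without a key even when one is configured.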
601
qwen3_6_scripts/chat_utils.py
Normal file
@@ -0,0 +1,601 @@
import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
                    Mapping, Optional, Tuple, TypeVar, Union, cast)

# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
                               ChatCompletionContentPartImageParam)
from openai.types.chat import (
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
                               ChatCompletionContentPartTextParam)
from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
                               ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
                                   async_get_and_parse_image,
                                   get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """


class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part."""


class CustomChatCompletionContentPartParam(TypedDict, total=False):
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""


ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

    reasoning_content: Optional[str]
    """Reasoning / thinking content for assistant messages (vLLM extension).
    When present in a previous assistant turn, it is rendered as
    <think>...</think> before the main content so the model sees its own
    chain-of-thought in subsequent turns."""


ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]


# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""
|
||||
|
||||
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
reasoning_content: Optional[str]
|
||||
"""Reasoning / thinking content for assistant messages.
|
||||
Passed directly to the chat template (Qwen3 reads message.reasoning_content
|
||||
natively) instead of being manually wrapped in <think>...</think>."""
|
||||
|
||||
|
||||
ModalityStr = Literal["image", "audio", "video"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
of multi-modal items in a given request does not exceed the configured
|
||||
maximum per prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||
super().__init__()
|
||||
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||
if model_config.multimodal_config else {})
|
||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||
|
||||
self._items: List[_T] = []
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=None)
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||
return tokenizer.decode(token_index)
|
||||
|
||||
def _placeholder_str(self, modality: ModalityStr,
|
||||
current_count: int) -> Optional[str]:
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
hf_config = self._model_config.hf_config
|
||||
model_type = hf_config.model_type
|
||||
|
||||
if modality == "image":
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
return f"<|image_{current_count}|>"
|
||||
if model_type == "minicpmv":
|
||||
return "(<image>./</image>)"
|
||||
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
|
||||
"pixtral"):
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
if model_type == "qwen":
|
||||
return f"Picture {current_count}: <img></img>"
|
||||
if model_type.startswith("llava"):
|
||||
return self._cached_token_str(self._tokenizer,
|
||||
hf_config.image_token_index)
|
||||
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
|
||||
return "<image>"
|
||||
if model_type == "mllama":
|
||||
return "<|image|>"
|
||||
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
if model_type == "molmo":
|
||||
return ""
|
||||
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type == "ultravox":
|
||||
return "<|reserved_special_token_0|>"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
elif modality == "video":
|
||||
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
else:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
@staticmethod
|
||||
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||
|
||||
# Merge all the multi-modal items
|
||||
for single_mm_data in items:
|
||||
for mm_key, mm_item in single_mm_data.items():
|
||||
if isinstance(mm_item, list):
|
||||
mm_lists[mm_key].extend(mm_item)
|
||||
else:
|
||||
mm_lists[mm_key].append(mm_item)
|
||||
|
||||
# Unpack any single item lists for models that don't expect multiple.
|
||||
return {
|
||||
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
||||
for mm_key, mm_list in mm_lists.items()
|
||||
}
|
||||
|
||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and return the
|
||||
placeholder string to use, if any.
|
||||
"""
|
||||
allowed_count = self._allowed_items.get(modality, 1)
|
||||
current_count = self._consumed_items.get(modality, 0) + 1
|
||||
if current_count > allowed_count:
|
||||
raise ValueError(
|
||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||
"one request.")
|
||||
|
||||
self._consumed_items[modality] = current_count
|
||||
self._items.append(item)
|
||||
|
||||
return self._placeholder_str(modality, current_count)
|
||||
|
||||
@abstractmethod
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
||||
|
||||
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
return self._combine(self._items) if self._items else None
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return MultiModalContentParser(self)
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(
|
||||
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
||||
|
||||
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
if self._items:
|
||||
items = await asyncio.gather(*self._items)
|
||||
return self._combine(items)
|
||||
|
||||
return None
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return AsyncMultiModalContentParser(self)
|
||||
|
||||
|
||||
class BaseMultiModalContentParser(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
|
||||
|
||||
def _add_placeholder(self, placeholder: Optional[str]):
|
||||
if placeholder:
|
||||
self._placeholder_counts[placeholder] += 1
|
||||
|
||||
def mm_placeholder_counts(self) -> Dict[str, int]:
|
||||
return dict(self._placeholder_counts)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image = get_and_parse_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
|
||||
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image_coro = async_get_and_parse_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = async_get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
"""Raises if the provided chat template appears invalid."""
|
||||
if chat_template is None:
|
||||
return
|
||||
|
||||
elif isinstance(chat_template, Path) and not chat_template.exists():
|
||||
raise FileNotFoundError(
|
||||
"the supplied chat template path doesn't exist")
|
||||
|
||||
elif isinstance(chat_template, str):
|
||||
JINJA_CHARS = "{}\n"
|
||||
if not any(c in chat_template
|
||||
for c in JINJA_CHARS) and not Path(chat_template).exists():
|
||||
raise ValueError(
|
||||
f"The supplied chat template string ({chat_template}) "
|
||||
f"appears path-like, but doesn't exist!")
|
||||
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{type(chat_template)} is not a valid chat template type")
|
||||
|
||||
|
||||
def load_chat_template(
|
||||
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
|
||||
if chat_template is None:
|
||||
return None
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
resolved_chat_template = f.read()
|
||||
except OSError as e:
|
||||
if isinstance(chat_template, Path):
|
||||
raise
|
||||
|
||||
JINJA_CHARS = "{}\n"
|
||||
if not any(c in chat_template for c in JINJA_CHARS):
|
||||
msg = (f"The supplied chat template ({chat_template}) "
|
||||
f"looks like a file path, but it failed to be "
|
||||
f"opened. Reason: {e}")
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# If opening a file fails, treat the argument as the template itself and
# decode it so that escape sequences are interpreted correctly
|
||||
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||
return resolved_chat_template
|
||||
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model."""
|
||||
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders: List[str] = []
|
||||
for placeholder in placeholder_counts:
|
||||
|
||||
# For any existing placeholder in the text prompt, we leave it as is
|
||||
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||
|
||||
if placeholder_counts[placeholder] < 0:
|
||||
raise ValueError(
|
||||
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||
"actual multimodal data items.")
|
||||
|
||||
missing_placeholders.extend([placeholder] *
|
||||
placeholder_counts[placeholder])
|
||||
|
||||
# NOTE: For now we always add missing placeholders at the front of
|
||||
# the prompt. This may change to be customizable in the future.
|
||||
return "\n".join(missing_placeholders + [text_prompt])
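
The placeholder reconciliation in `_get_full_multimodal_text_prompt` can be sketched on its own. This is a minimal illustration, not the code from the diff: the name `fill_missing_placeholders` is hypothetical, and unlike the function above it does not mutate the counts dictionary it receives.

```python
from typing import Dict, List


def fill_missing_placeholders(placeholder_counts: Dict[str, int],
                              text_prompt: str) -> str:
    """Prepend placeholders the prompt lacks; reject surplus ones."""
    missing: List[str] = []
    for placeholder, expected in placeholder_counts.items():
        # Placeholders the user already wrote stay where they are.
        remaining = expected - text_prompt.count(placeholder)
        if remaining < 0:
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt "
                "than actual multimodal data items.")
        missing.extend([placeholder] * remaining)
    # Missing placeholders go at the front, one per line.
    return "\n".join(missing + [text_prompt])


# Two images were attached, but the prompt mentions only one of them.
two_images = fill_missing_placeholders(
    {"<image>": 2}, "Compare <image> with the other one.")
print(two_images)
```

Only the placeholder that the prompt does not already contain is prepended, so a user who positions placeholders by hand keeps that layout.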
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
|
||||
|
||||
|
||||
def _parse_chat_message_content_parts(
|
||||
role: str,
|
||||
parts: Iterable[ChatCompletionContentPartParam],
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
) -> List[ConversationMessage]:
|
||||
texts: List[str] = []
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
keep_multimodal_content = \
|
||||
mm_tracker._model_config.hf_config.model_type in \
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT
|
||||
|
||||
has_image = False
|
||||
for part in parts:
|
||||
part_type = part["type"]
|
||||
if part_type == "text":
|
||||
text = _TextParser(part)["text"]
|
||||
texts.append(text)
|
||||
elif part_type == "image_url":
|
||||
image_url = _ImageParser(part)["image_url"]
|
||||
|
||||
if image_url.get("detail", "auto") != "auto":
|
||||
logger.warning(
|
||||
"'image_url.detail' is currently not supported and "
|
||||
"will be ignored.")
|
||||
|
||||
mm_parser.parse_image(image_url["url"])
|
||||
has_image = True
|
||||
elif part_type == "audio_url":
|
||||
audio_url = _AudioParser(part)["audio_url"]
|
||||
|
||||
mm_parser.parse_audio(audio_url["url"])
|
||||
elif part_type == "refusal":
|
||||
text = _RefusalParser(part)["refusal"]
|
||||
texts.append(text)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
text_prompt = "\n".join(texts)
|
||||
if keep_multimodal_content:
|
||||
text_prompt = "\n".join(texts)
|
||||
role_content = [{'type': 'text', 'text': text_prompt}]
|
||||
|
||||
if has_image:
|
||||
role_content = [{'type': 'image'}] + role_content
|
||||
return [ConversationMessage(role=role,
|
||||
content=role_content)] # type: ignore
|
||||
else:
|
||||
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||
if mm_placeholder_counts:
|
||||
text_prompt = _get_full_multimodal_text_prompt(
|
||||
mm_placeholder_counts, text_prompt)
|
||||
return [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
|
||||
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
|
||||
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
) -> List[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
|
||||
if content is None:
|
||||
content = []
|
||||
elif isinstance(content, str):
|
||||
content = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=content)
|
||||
]
|
||||
|
||||
result = _parse_chat_message_content_parts(
|
||||
role,
|
||||
content, # type: ignore
|
||||
mm_tracker,
|
||||
)
|
||||
|
||||
for result_msg in result:
|
||||
if role == 'assistant':
|
||||
parsed_msg = _AssistantParser(message)
|
||||
|
||||
if "tool_calls" in parsed_msg:
|
||||
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
||||
|
||||
# Pass reasoning content as a dedicated field so the chat template
|
||||
# can render it natively (Qwen3: message.reasoning_content branch).
|
||||
# Accept both "reasoning" (new vllm) and "reasoning_content" (ours).
|
||||
reasoning = (message.get("reasoning") # type: ignore[arg-type]
|
||||
or message.get("reasoning_content")) # type: ignore[arg-type]
|
||||
if reasoning and isinstance(reasoning, str):
|
||||
result_msg["reasoning_content"] = reasoning
|
||||
|
||||
elif role == "tool":
|
||||
parsed_msg = _ToolParser(message)
|
||||
if "tool_call_id" in parsed_msg:
|
||||
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
|
||||
|
||||
if "name" in message and isinstance(message["name"], str):
|
||||
result_msg["name"] = message["name"]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||
# per the Transformers docs & maintainers, tool call arguments in
|
||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||
# this is how tool-use chat templates will expect them moving forwards
|
||||
# so, for messages that have tool_calls, parse the string (which we get
|
||||
# from openAI format) to dict
|
||||
for message in messages:
|
||||
if (message["role"] == "assistant" and "tool_calls" in message
|
||||
and isinstance(message["tool_calls"], list)):
|
||||
|
||||
for item in message["tool_calls"]:
|
||||
item["function"]["arguments"] = json.loads(
|
||||
item["function"]["arguments"])
|
||||
|
||||
|
||||
def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: List[ConversationMessage],
|
||||
chat_template: Optional[str],
|
||||
*,
|
||||
tokenize: bool = False, # Different from HF's default
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if chat_template is None and tokenizer.chat_template is None:
|
||||
raise ValueError(
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one.")
|
||||
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
chat_template=chat_template,
|
||||
tokenize=tokenize,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[int]:
|
||||
if chat_template is not None:
|
||||
logger.warning(
|
||||
"'chat_template' cannot be overridden for mistral tokenizer.")
|
||||
if "add_generation_prompt" in kwargs:
|
||||
logger.warning(
|
||||
"'add_generation_prompt' is not supported for mistral tokenizer, "
|
||||
"so it will be ignored.")
|
||||
if "continue_final_message" in kwargs:
|
||||
logger.warning(
|
||||
"'continue_final_message' is not supported for mistral tokenizer, "
|
||||
"so it will be ignored.")
|
||||
|
||||
return tokenizer.apply_chat_template(
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
)
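
The assistant-turn handling in `_parse_chat_message_content` and `_postprocess_messages` above can be summarised in a small standalone sketch. The helper name `normalize_assistant_turn` and the plain-dict message shape are assumptions made for illustration; the sketch only mirrors two behaviours visible in the diff: accepting either `reasoning` or `reasoning_content`, and turning JSON-string tool call arguments into dicts.

```python
import json
from typing import Any, Dict


def normalize_assistant_turn(message: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the message in the shape a chat template expects."""
    result: Dict[str, Any] = {
        "role": message["role"],
        "content": message.get("content"),
    }
    if message["role"] != "assistant":
        return result

    if "tool_calls" in message:
        calls = []
        for call in message["tool_calls"]:
            function = dict(call["function"])
            # OpenAI-format arguments arrive as a JSON string; templates
            # are expected to receive a dict instead.
            function["arguments"] = json.loads(function["arguments"])
            calls.append({**call, "function": function})
        result["tool_calls"] = calls

    # Accept both field names; only non-empty strings are forwarded.
    reasoning = message.get("reasoning") or message.get("reasoning_content")
    if reasoning and isinstance(reasoning, str):
        result["reasoning_content"] = reasoning
    return result


turn = normalize_assistant_turn({
    "role": "assistant",
    "content": "4",
    "reasoning": "2 + 2 = 4",
    "tool_calls": [{
        "id": "call_1",
        "type": "function",
        "function": {"name": "add", "arguments": "{\"a\": 2, \"b\": 2}"},
    }],
})
print(turn["reasoning_content"], turn["tool_calls"][0]["function"]["arguments"])
```

Keeping the reasoning text in its own field, rather than wrapping it in `<think>...</think>` by hand, leaves the rendering decision to the chat template.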
|
||||
261
qwen3_6_scripts/cli_args.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
This file contains the command line arguments for vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.chat_utils import validate_chat_template
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: List[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, '']: # Skip if item is None or empty string
|
||||
continue
|
||||
if '=' in item and ',' not in item: # Old format: name=path
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
adapter_list: List[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument("--host",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument(
|
||||
"--uvicorn-log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
|
||||
help="log level for uvicorn")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--api-key",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="If provided, the server will require this key "
|
||||
"to be presented in the header.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in either 'name=path' format "
|
||||
"or JSON format. "
|
||||
"Example (old format): 'name=path' "
|
||||
"Example (new format): "
|
||||
"'{\"name\": \"name\", \"local_path\": \"path\", "
|
||||
"\"base_model_name\": \"id\"}'")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=PromptAdapterParserAction,
|
||||
help="Prompt adapter configurations in the format name=path. "
|
||||
"Multiple adapters can be specified.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model")
|
||||
parser.add_argument("--response-role",
|
||||
type=nullable_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the SSL key file")
|
||||
parser.add_argument("--ssl-certfile",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the SSL cert file")
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The CA certificates file")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see the stdlib ssl module's CERT_* constants)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy")
|
||||
parser.add_argument(
|
||||
"--middleware",
|
||||
type=nullable_str,
|
||||
action="append",
|
||||
default=[],
|
||||
help="Additional ASGI middleware to apply to the app. "
|
||||
"We accept multiple --middleware arguments. "
|
||||
"The value should be an import path. "
|
||||
"If a function is provided, vLLM will add it to the server "
|
||||
"using @app.middleware('http'). "
|
||||
"If a class is provided, vLLM will add it to the server "
|
||||
"using app.add_middleware(). ")
|
||||
parser.add_argument(
|
||||
"--return-tokens-as-token-ids",
|
||||
action="store_true",
|
||||
help="When --max-logprobs is specified, represents single tokens as "
|
||||
"strings of the form 'token_id:{token_id}' so that tokens that "
|
||||
"are not JSON-encodable can be identified.")
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
help="If specified, will run the OpenAI frontend server in the same "
|
||||
"process as the model serving engine.")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
"Enable auto tool choice for supported models. Use --tool-call-parser "
"to specify which parser to use")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="",
|
||||
help=
"Specify the tool parser plugin used to parse model-generated tool calls"
" into OpenAI API format; names registered in this plugin can be used "
"in --tool-call-parser.")
|
||||
|
||||
parser.add_argument(
|
||||
"--reasoning-parser",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Select the reasoning parser to split <think>...</think> content into "
|
||||
"reasoning_content vs content in the response. "
|
||||
"Supported: qwen3")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-fastapi-docs",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# --enable-auto-tool-choice is only valid together with a tool call parser
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
return make_arg_parser(parser_for_docs)
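
The two `--lora-modules` input formats handled by `LoRAParserAction` above can be illustrated without argparse. This is a sketch under stated assumptions: `LoRAModule` is a hypothetical stand-in for `LoRAModulePath`, its field names follow the help string in this file rather than the real class, and the old format is split on the first `=` only, which differs from the `item.split('=')` call in the diff.

```python
import json
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LoRAModule:
    """Hypothetical stand-in for LoRAModulePath; field names are assumed."""
    name: str
    local_path: str
    base_model_name: Optional[str] = None


def parse_lora_modules(values: List[str]) -> List[LoRAModule]:
    modules: List[LoRAModule] = []
    for item in values:
        if not item:  # Skip None and empty strings.
            continue
        if "=" in item and "," not in item:
            # Old format: name=path. Splitting once tolerates '=' in the path.
            name, path = item.split("=", 1)
            modules.append(LoRAModule(name, path))
        else:
            # New format: one JSON object per module.
            modules.append(LoRAModule(**json.loads(item)))
    return modules


parsed = parse_lora_modules([
    "sql=/adapters/sql",
    '{"name": "chat", "local_path": "/adapters/chat", "base_model_name": "base"}',
])
print(parsed)
```

In the real action, JSON and field errors are reported through `parser.error`; the sketch lets the exceptions propagate to keep it short.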
|
||||
@@ -309,28 +309,23 @@ class PagedAttention:
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Pure-PyTorch prefix-attention with query-chunking (no Triton).
|
||||
"""Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
|
||||
|
||||
For each sequence, gathers the context KV from the paged KV cache,
|
||||
concatenates with the current-chunk K/V, then computes scaled-dot-
|
||||
product attention with a causal mask.
|
||||
Memory complexity: O(q_len), independent of kv_len.
|
||||
With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
|
||||
per layer ≈ 96 MB regardless of context length.
|
||||
|
||||
Memory optimisation — GQA-aware Q-tiling
|
||||
-----------------------------------------
|
||||
Two complementary tricks keep peak activation memory well below 1 GB
|
||||
even for 100K context on TP=4 (kv_h=1, q_h=6):
|
||||
Algorithm: Flash Attention online softmax.
|
||||
Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
|
||||
K-tiles. For each tile a running (m, l, o) accumulator is updated —
|
||||
the [q_len × kv_len] attention matrix is NEVER materialised in full.
|
||||
|
||||
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
|
||||
resolution and GQA grouping is handled via 4D reshape+broadcast
|
||||
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
|
||||
vs the old expand-then-float32 approach:
|
||||
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
|
||||
New: [1, 100K, 256] fp32 = 98 MB each for K and V
|
||||
|
||||
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
|
||||
is bounded to ~148 MB at 100K instead of growing with q_len.
|
||||
|
||||
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
|
||||
Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
|
||||
q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
|
||||
o_acc same shape 24 MB (held all tiles)
|
||||
s same shape 24 MB (per tile, freed before exp_s)
|
||||
exp_s same shape 24 MB (per tile, brief overlap with s)
|
||||
Peak ≈ 96 MB (s and exp_s briefly coexist during update).
|
||||
|
||||
Shapes
|
||||
------
|
||||
@@ -344,29 +339,24 @@ class PagedAttention:
|
||||
seq_lens_tensor: [batch_size] total length (context + query)
|
||||
context_lens : [batch_size] tokens already in KV cache
|
||||
"""
|
||||
# Memory-efficient query-chunked attention.
|
||||
# Key optimisation: do NOT expand KV heads for GQA before materialising
|
||||
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
|
||||
# resolution saves ~6× memory vs expanding to q_h first:
|
||||
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
|
||||
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
|
||||
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
|
||||
# reshaping, so no extra tensors are created.
|
||||
try:
|
||||
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
|
||||
# Paged-block tiles for context phase.
|
||||
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
|
||||
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
|
||||
# Same tile size reused for the current-chunk phase.
|
||||
_BLOCKS_PER_TILE = 32
|
||||
|
||||
batch_size = seq_lens_tensor.shape[0]
|
||||
num_q_heads = query.shape[1]
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
head_dim = query.shape[2]
|
||||
gqa_ratio = num_q_heads // num_kv_heads
|
||||
|
||||
# value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
output = torch.empty_like(query)
|
||||
orig_dtype = query.dtype
|
||||
block_size = value_cache.shape[3]
|
||||
tile_sz = _BLOCKS_PER_TILE * block_size
|
||||
scale = head_dim ** -0.5
|
||||
orig_dtype = query.dtype
|
||||
output = torch.empty_like(query)
|
||||
dev = query.device
|
||||
|
||||
for i in range(batch_size):
|
||||
ctx_len = int(context_lens[i].item())
|
||||
@@ -374,96 +364,147 @@ class PagedAttention:
|
||||
q_end = int(query_start_loc[i + 1].item())
|
||||
q_len = q_end - q_start
|
||||
|
||||
q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim]
|
||||
k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim]
|
||||
q_i = query[q_start:q_end] # [q_len, q_h, d]
|
||||
k_i = key [q_start:q_end] # [q_len, kv_h, d]
|
||||
v_i = value[q_start:q_end]
|
||||
|
||||
# --- Build full K/V (context from cache + current chunk) ----
|
||||
# Q reshaped and scaled once; held for all K-tiles.
|
||||
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
|
||||
q_seq = (q_i.permute(1, 0, 2)
|
||||
.float()
|
||||
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
|
||||
.mul_(scale))
|
||||
|
||||
# Flash-Attention online-softmax accumulators.
|
||||
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
|
||||
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
|
||||
m = torch.full((num_kv_heads, gqa_ratio, q_len),
|
||||
float('-inf'), dtype=torch.float32, device=dev)
|
||||
l = torch.zeros_like(m)
|
||||
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
|
||||
dtype=torch.float32, device=dev)
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Phase 1 — context tokens (positions 0 … ctx_len-1).
|
||||
#
|
||||
# Every context key has absolute position < ctx_len; every
|
||||
# query has position ≥ ctx_len. k_pos < q_pos is always True
|
||||
# → no causal mask needed for pure context tiles.
|
||||
# --------------------------------------------------------------
|
||||
if ctx_len > 0:
|
||||
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
|
||||
blk_ids = block_tables[i, :num_ctx_blocks]
|
||||
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
|
||||
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
|
||||
blk_ids = block_tables[i, tile_blk:blk_end]
|
||||
|
||||
# key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
||||
# → permute(0,3,1,2,4) → contiguous → view → [:ctx_len]
|
||||
k_ctx = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
||||
# Gather K/V for this tile.
|
||||
# key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
||||
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
||||
k_tile = (key_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))
|
||||
v_tile = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))
|
||||
|
||||
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
||||
# → permute(0,3,1,2) → contiguous → view → [:ctx_len]
|
||||
v_ctx = (value_cache[blk_ids]
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
||||
# Trim padding in the last block of the tile.
|
||||
valid = (min(blk_end * block_size, ctx_len)
|
||||
- tile_blk * block_size)
|
||||
k_tile = k_tile[:valid] # [valid, kv_h, d]
|
||||
v_tile = v_tile[:valid]
|
||||
|
||||
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
|
||||
v_full = torch.cat([v_ctx, v_i], dim=0)
|
||||
del k_ctx, v_ctx
|
||||
else:
|
||||
k_full = k_i
|
||||
v_full = v_i
|
||||
# k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
|
||||
# v_t: [kv_h, 1, valid, d]
|
||||
k_t = (k_tile.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.transpose(-1, -2)
|
||||
.float())
|
||||
v_t = (v_tile.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.float())
|
||||
del k_tile, v_tile
|
||||
|
||||
kv_len = k_full.shape[0] # ctx_len + q_len
|
||||
# Scores: [kv_h, gqa, q_len, valid]
|
||||
s = torch.matmul(q_seq, k_t)
|
||||
del k_t
|
||||
# No causal mask: all context keys precede all queries.
|
||||
|
||||
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
|
||||
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
|
||||
# attn_w is computed in fp32 (q cast to fp32 before matmul, then
|
||||
# k cast inline) so softmax precision is unaffected.
|
||||
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
|
||||
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del k_full
|
||||
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
||||
del v_full
|
||||
# Online softmax update — Flash-Attention Algorithm 1.
|
||||
# exp_s = s - new_max (in-place exp after del s)
|
||||
m_blk = s.amax(dim=-1)
|
||||
m_new = torch.maximum(m, m_blk)
|
||||
exp_s = s - m_new.unsqueeze(-1)
|
||||
del s
|
||||
exp_s.exp_()
|
||||
corr = torch.exp(m - m_new)
|
||||
m.copy_(m_new)
|
||||
del m_blk, m_new
|
||||
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||
o.mul_(corr.unsqueeze(-1)).add_(
|
||||
torch.matmul(exp_s, v_t))
|
||||
del exp_s, v_t, corr
|
||||
|
||||
# k_pos used for causal mask: shape [kv_len]
|
||||
k_pos = torch.arange(kv_len, device=query.device)
|
||||
# --------------------------------------------------------------
|
||||
# Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
|
||||
#
|
||||
# Causal mask: query at relative position j sees key at relative
|
||||
# position k only when k ≤ j. Tiles of tile_sz tokens each.
|
||||
# --------------------------------------------------------------
|
||||
for kc_start in range(0, q_len, tile_sz):
|
||||
kc_end = min(kc_start + tile_sz, q_len)
|
||||
kc_len = kc_end - kc_start
|
||||
|
||||
# --- Query-chunked attention with lazy GQA grouping ----------
|
||||
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
|
||||
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
|
||||
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
|
||||
for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
|
||||
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
|
||||
qc = qc_end - qc_start
|
||||
k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
|
||||
v_blk = v_i[kc_start:kc_end]
|
||||
|
||||
# [kv_h, gqa_ratio, qc, d]
|
||||
q_t_chunk = (q_i[qc_start:qc_end]
|
||||
.permute(1, 0, 2) # [q_h, qc, d]
|
||||
.float()
|
||||
.view(num_kv_heads, gqa_ratio, qc, head_dim))
|
||||
k_t = (k_blk.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.transpose(-1, -2)
|
||||
.float()) # [kv_h, 1, d, kc_len]
|
||||
v_t = (v_blk.permute(1, 0, 2)
|
||||
.unsqueeze(1)
|
||||
.float()) # [kv_h, 1, kc_len, d]
|
||||
|
||||
# [kv_h, gqa_ratio, qc, kv_len]
|
||||
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
|
||||
# Cast k slice to fp32 inline; the temporary is freed after matmul.
|
||||
attn_w = torch.matmul(q_t_chunk * scale,
|
||||
k_t.unsqueeze(1).transpose(-1, -2).float())
|
||||
s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
|
||||
del k_t
|
||||
|
||||
# Causal mask for this sub-chunk:
|
||||
# query absolute position = ctx_len + qc_start..qc_end-1
|
||||
qc_q_pos = torch.arange(qc_start, qc_end,
|
||||
device=query.device)
|
||||
# mask[j, k] = True → future key, block it
|
||||
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
|
||||
attn_w.masked_fill_(
|
||||
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||
# Causal mask: key at (kc_start+k) must not exceed query j.
|
||||
k_rel = torch.arange(kc_start, kc_end, device=dev)
|
||||
q_rel = torch.arange(q_len, device=dev)
|
||||
mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
|
||||
s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||
del mask, k_rel, q_rel
|
||||
|
||||
# In-place numerically stable softmax — avoids allocating a
|
||||
# new 150 MB tensor (same size as attn_w) that torch.softmax
|
||||
# would create, which exhausts the fragmented GPU pool.
|
||||
attn_w -= attn_w.amax(dim=-1, keepdim=True)
|
||||
attn_w.exp_()
|
||||
attn_w /= attn_w.sum(dim=-1, keepdim=True)
|
||||
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
|
||||
out_c = torch.matmul(attn_w,
|
||||
v_t.unsqueeze(1).float())
|
||||
# reshape to [q_h, qc, d] then [qc, q_h, d]
|
||||
out_c = out_c.view(num_q_heads, qc, head_dim)
|
||||
# Online softmax update (identical to context phase).
|
||||
m_blk = s.amax(dim=-1)
|
||||
m_new = torch.maximum(m, m_blk)
|
||||
exp_s = s - m_new.unsqueeze(-1)
|
||||
del s
|
||||
exp_s.exp_()
|
||||
corr = torch.exp(m - m_new)
|
||||
m.copy_(m_new)
|
||||
del m_blk, m_new
|
||||
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||
o.mul_(corr.unsqueeze(-1)).add_(
|
||||
torch.matmul(exp_s, v_t))
|
||||
del exp_s, v_t, corr
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# Finalize: normalize running output by normalization factor.
|
||||
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
|
||||
# --------------------------------------------------------------
|
||||
o.div_(l.unsqueeze(-1))
|
||||
output[q_start:q_end] = (
|
||||
o.view(num_q_heads, q_len, head_dim)
|
||||
.permute(1, 0, 2)
|
||||
.to(orig_dtype)
|
||||
)
|
||||
|
||||
output[q_start + qc_start : q_start + qc_end] = (
|
||||
out_c.to(orig_dtype).permute(1, 0, 2))
|
||||
except Exception as e:
|
||||
print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True)
|
||||
print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
|
||||
file=sys.stderr, flush=True)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
raise
|
||||
return output
|
||||
|
||||
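The online-softmax update used in both phases above can be illustrated with a scalar sketch. This is not the code from `paged_attn.py`: it uses plain Python lists for a single query row, the function names are hypothetical, and it assumes every tile contains at least one unmasked score (otherwise the running maximum stays at `-inf` and the correction factor is undefined).

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: softmax(scores) dotted with values, computed in one pass."""
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)


def tiled_softmax_weighted_sum(scores, values, tile_sz):
    """Same result, but only one tile of scores is looked at per step."""
    m = float("-inf")  # running maximum
    l = 0.0            # running sum of exponentials
    o = 0.0            # running unnormalised output
    for start in range(0, len(scores), tile_sz):
        s_tile = scores[start:start + tile_sz]
        v_tile = values[start:start + tile_sz]
        m_new = max(m, max(s_tile))
        corr = math.exp(m - m_new)  # rescales old accumulators; 0.0 on the first tile
        exp_s = [math.exp(s - m_new) for s in s_tile]
        l = l * corr + sum(exp_s)
        o = o * corr + sum(e * v for e, v in zip(exp_s, v_tile))
        m = m_new
    return o / l  # final normalisation, as in the "Finalize" step


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
full = softmax_weighted_sum(scores, values)
tiled = tiled_softmax_weighted_sum(scores, values, tile_sz=3)
print(abs(full - tiled) < 1e-12)
```

Because the accumulators are rescaled by `corr` whenever the maximum changes, the tile size only affects memory use, not the result.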
@@ -10,18 +10,11 @@
|
||||
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
|
||||
#
|
||||
# Important: Qwen3.6-27B must use the TP=4, PP=2 combination to deploy on 8 GPUs
|
||||
#
|
||||
# Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism:
|
||||
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
|
||||
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
|
||||
# --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \
|
||||
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
||||
# --max-num-batched-tokens 50000
|
||||
|
||||
# Recommended server start command for TP=4 with 100K context (requires chunked prefill):
|
||||
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
|
||||
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
|
||||
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 8 --gpu-memory-utilization 0.95 \
|
||||
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
|
||||
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
||||
# --max-num-batched-tokens 4096 --enable-chunked-prefill
|
||||
|
||||
@@ -29,13 +22,14 @@
|
||||
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
|
||||
# (standard Triton 2.3.1 PTX is not supported by the corex runtime either).
|
||||
# Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which
|
||||
# also implements query-chunking (_ATTN_Q_CHUNK=256) to keep peak attention
|
||||
# memory at O(256 × kv_len) instead of O(q_len × kv_len).
|
||||
cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
||||
# uses K-tiling, and also provides _forward_decode_pytorch to bypass the kernel
# when the context length is large
|
||||
cp ./paged_attn.py /usr/local/corex/lib/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
||||
|
||||
# --- transformers: Qwen3_5 tokenizer / model files --------------------------
|
||||
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||
cp -r ./qwen3_5_moe /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||
python3 ./patch_transformers_qwen3_5.py
|
||||
|
||||
# --- vllm model: Qwen3.6-27B (Qwen3_5 arch) --------------------------------
|
||||
@@ -43,10 +37,36 @@ cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_execut
|
||||
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/qwen3_5.py
|
||||
python3 ./patch_vllm_qwen3_5.py
|
||||
|
||||
# --- sequence.py: fix completion_tokens inflation under chunked prefill ------
|
||||
# Bug: get_output_token_ids_to_return(delta=True) with num_new_tokens=0
|
||||
# returns _cached_all_token_ids[-0:] == [0:] (the ENTIRE prompt+output list).
|
||||
# Each prefill chunk step adds prompt_len to previous_num_tokens, so a 10K
|
||||
# prompt processed in 3 chunks inflates completion_tokens by ~30K.
|
||||
cp ./sequence.py /usr/local/corex/lib/python3/dist-packages/vllm/sequence.py
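The slicing pitfall described in the comment above is easy to reproduce in isolation. The sketch below is only an illustration of the Python behaviour; the helper names are hypothetical and this is not the patched code from `sequence.py`.

```python
def last_n_buggy(tokens, n):
    # tokens[-0:] is tokens[0:], i.e. the whole list, not an empty one.
    return tokens[-n:]


def last_n_safe(tokens, n):
    # Guard the n == 0 case explicitly (one possible fix, assumed here).
    return tokens[-n:] if n > 0 else []


all_token_ids = [101, 102, 103, 104]  # stand-in for prompt + output ids
print(last_n_buggy(all_token_ids, 0))  # whole list comes back
print(last_n_safe(all_token_ids, 0))   # empty list
```

During a prefill chunk no new output token is produced, so `n` is 0 and the unguarded slice hands back the full prompt on every chunk step.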
|
||||
# --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------
# Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py.
# Required because head_dim=256 > 128 and ixformer flash attention either
# crashes (is_causal=True) or produces wrong output (attn_mask path).
# The fallback uses query_start_loc to derive actual query lengths, so it
# works correctly during profiling runs with chunked-prefill-style batches.
# also bypasses auto chunked prefill on
python3 ./patch_xformers_sdpa_seq.py

# --- tool parser: Qwen3 XML tool call format ---------------------------------
# Registers "qwen3_coder" parser for Qwen3.6 XML-style tool calls:
# <tool_call><function=name><parameter=key>\nvalue\n</parameter></function></tool_call>
# Use at server start: --tool-call-parser qwen3_coder --enable-auto-tool-choice
cp ./qwen3coder_tool_parser.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
python3 ./patch_vllm_tool_parser.py

# --- reasoning parser: Qwen3 <think>...</think> split ------------------------
# Adds --reasoning-parser qwen3 support.
# Routes thinking tokens to reasoning_content, rest to content in the delta.
# Works together with --tool-call-parser qwen3_coder (think → tool call flow).
cp -r ./reasoning /usr/local/corex/lib/python3/dist-packages/vllm/
cp ./protocol.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/protocol.py
cp ./cli_args.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/cli_args.py
cp ./serving_chat.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/serving_chat.py
cp ./api_server.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/api_server.py
cp ./chat_utils.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/chat_utils.py
@@ -1,9 +1,10 @@
|
||||
"""
|
||||
Patches transformers 4.55.3 to register the qwen3_5 model type.
|
||||
Patches transformers 4.55.3 to register qwen3_5 and qwen3_5_moe model types.
|
||||
|
||||
Deploy steps on the remote machine:
|
||||
1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5
|
||||
2. python3 modified_scripts/patch_transformers_qwen3_5.py
|
||||
1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5
|
||||
2. cp -r modified_scripts/qwen3_5_moe /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5_moe
|
||||
3. python3 modified_scripts/patch_transformers_qwen3_5.py
|
||||
|
||||
Target: pip-installed transformers at /usr/local/lib/python3.10/site-packages/transformers/
|
||||
(Not the corex pre-installed path at /usr/local/corex/lib64/python3/dist-packages/)
|
||||
@@ -40,24 +41,23 @@ def patch_file(path, replacements):
|
||||
def main():
|
||||
print(f"=== Patching {AUTO_CONFIG} ===")
|
||||
patch_file(AUTO_CONFIG, [
|
||||
# CONFIG_MAPPING_NAMES: insert qwen3_5 right after qwen3
|
||||
# CONFIG_MAPPING_NAMES: insert qwen3_5 + qwen3_5_moe right after qwen3
|
||||
(
|
||||
'("qwen3", "Qwen3Config"),',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n ("qwen3_5_moe", "Qwen3_5MoeConfig"),',
|
||||
),
|
||||
# Some versions don't have trailing comma — handle that too
|
||||
(
|
||||
'("qwen3", "Qwen3Config")\n',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n',
|
||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n ("qwen3_5_moe", "Qwen3_5MoeConfig"),\n',
|
||||
),
|
||||
# MODEL_NAMES_MAPPING (model_type -> human readable name, used by docstring generator)
|
||||
# MODEL_NAMES_MAPPING (model_type -> human readable name)
|
||||
(
|
||||
'("qwen3", "Qwen3"),',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n ("qwen3_5_moe", "Qwen3_5_MoE"),',
|
||||
),
|
||||
(
|
||||
'("qwen3", "Qwen3")\n',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n',
|
||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n ("qwen3_5_moe", "Qwen3_5_MoE"),\n',
|
||||
),
|
||||
])
|
||||
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
patch_file(MODELS_INIT, [
|
||||
(
|
||||
"from .qwen3 import *\n",
|
||||
"from .qwen3 import *\n from .qwen3_5 import *\n",
|
||||
"from .qwen3 import *\n from .qwen3_5 import *\n from .qwen3_5_moe import *\n",
|
||||
),
|
||||
])
|
||||
|
||||
@@ -74,19 +74,39 @@ def main():
|
||||
try:
|
||||
import importlib.util, types
|
||||
|
||||
# Quick smoke-test: import the config class directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"configuration_qwen3_5",
|
||||
def _load_config_mod(module_name, file_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
mod.__package__ = ".".join(module_name.split(".")[:-1])
|
||||
pkg = sys.modules.setdefault("transformers", types.ModuleType("transformers"))
|
||||
pkg.__path__ = [TRANSFORMERS_ROOT]
|
||||
cu = sys.modules.setdefault(
|
||||
"transformers.configuration_utils", types.ModuleType("transformers.configuration_utils"))
|
||||
class _PC:
|
||||
def __init__(self, **kwargs): pass
|
||||
cu.PretrainedConfig = _PC
|
||||
for sub in ("transformers.models", f"transformers.models.{module_name.split('.')[-2]}"):
|
||||
m = sys.modules.setdefault(sub, types.ModuleType(sub))
|
||||
m.__path__ = [TRANSFORMERS_ROOT]
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
mod27 = _load_config_mod(
|
||||
"transformers.models.qwen3_5.configuration_qwen3_5",
|
||||
f"{TRANSFORMERS_ROOT}/models/qwen3_5/configuration_qwen3_5.py",
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
# Provide minimal parent package stubs so relative imports resolve
|
||||
pkg = types.ModuleType("transformers")
|
||||
pkg.__path__ = [TRANSFORMERS_ROOT]
|
||||
sys.modules.setdefault("transformers", pkg)
|
||||
spec.loader.exec_module(mod)
|
||||
cfg = mod.Qwen3_5Config()
|
||||
print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})")
|
||||
cfg = mod27.Qwen3_5Config()
|
||||
print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})")
|
||||
|
||||
mod35 = _load_config_mod(
|
||||
"transformers.models.qwen3_5_moe.configuration_qwen3_5_moe",
|
||||
f"{TRANSFORMERS_ROOT}/models/qwen3_5_moe/configuration_qwen3_5_moe.py",
|
||||
)
|
||||
moe_cfg = mod35.Qwen3_5MoeConfig()
|
||||
print(f" Qwen3_5MoeConfig() smoke-test OK (model_type={moe_cfg.model_type})")
|
||||
t = moe_cfg.text_config
|
||||
print(f" num_experts={t.num_experts}, top_k={t.num_experts_per_tok}, "
|
||||
f"shared={t.shared_expert_intermediate_size}, layers={t.num_hidden_layers}")
|
||||
except Exception as e:
|
||||
print(f" [warn] smoke-test failed (may be fine at runtime): {e}")
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ def main():
|
||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),',
|
||||
' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n'
|
||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),\n'
|
||||
' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),',
|
||||
' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),\n'
|
||||
' "Qwen3_5MoeForCausalLM": ("qwen3_5", "Qwen3_5MoeForCausalLM"),',
|
||||
),
|
||||
])
|
||||
|
||||
@@ -61,11 +62,13 @@ def main():
|
||||
spec.loader.exec_module(mod)
|
||||
cls = mod.Qwen3_5ForCausalLM
|
||||
print(f" Qwen3_5ForCausalLM found: {cls}")
|
||||
cls_moe = mod.Qwen3_5MoeForCausalLM
|
||||
print(f" Qwen3_5MoeForCausalLM found: {cls_moe}")
|
||||
except Exception as e:
|
||||
print(f" [warn] verification failed (may be OK at runtime): {e}")
|
||||
|
||||
print("\nDone. Remember to:")
|
||||
print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM']")
|
||||
print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM'] or ['Qwen3_5MoeForCausalLM']")
|
||||
print(" 2. Run patch_transformers_qwen3_5.py if not already done")
|
||||
|
||||
|
||||
|
||||
79
qwen3_6_scripts/patch_vllm_tool_parser.py
Normal file
@@ -0,0 +1,79 @@
"""
Patches vLLM 0.6.3 to register Qwen3CoderToolParser under the name "qwen3_coder".

Deploy steps on the remote machine (already called by patch_ops.sh):
  1. cp qwen3coder_tool_parser.py \
       /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
  2. python3 patch_vllm_tool_parser.py

Usage after patching:
  --tool-call-parser qwen3_coder --enable-auto-tool-choice
"""

import os

VLLM_ROOT = "/usr/local/corex/lib/python3/dist-packages/vllm"
TOOL_PARSERS_DIR = f"{VLLM_ROOT}/entrypoints/openai/tool_parsers"
INIT_FILE = f"{TOOL_PARSERS_DIR}/__init__.py"


def patch_file(path, replacements):
    with open(path, "r") as f:
        content = f.read()

    patched = False
    for old, new in replacements:
        if new in content:
            print(f" [skip] already patched: {repr(new[:70])}")
            continue
        if old not in content:
            print(f" [warn] anchor not found: {repr(old[:70])}")
            continue
        content = content.replace(old, new, 1)
        patched = True
        print(f" [ok] patched: {repr(old[:50])} -> {repr(new[:50])}")

    if patched:
        with open(path, "w") as f:
            f.write(content)


def main():
    if not os.path.isdir(TOOL_PARSERS_DIR):
        raise FileNotFoundError(
            f"Tool parsers directory not found: {TOOL_PARSERS_DIR}\n"
            "Verify the vLLM installation path.")

    print(f"=== Patching {INIT_FILE} ===")
    patch_file(INIT_FILE, [
        (
            "from .mistral_tool_parser import MistralToolParser",
            "from .mistral_tool_parser import MistralToolParser\n"
            "from .qwen3coder_tool_parser import Qwen3CoderToolParser",
        ),
        (
            '"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"\n]',
            '"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",\n'
            ' "Qwen3CoderToolParser"\n]',
        ),
    ])

    print("\n=== Verification ===")
    try:
        import importlib.util
        spec = importlib.util.spec_from_file_location(
            "qwen3coder_tool_parser",
            f"{TOOL_PARSERS_DIR}/qwen3coder_tool_parser.py",
        )
        mod = importlib.util.module_from_spec(spec)
        print(f" Module spec loaded: {spec.name}")
        print(" (full import requires torch/vllm runtime — skipping exec)")
    except Exception as e:
        print(f" [warn] spec check failed: {e}")

    print("\nDone. Start vLLM server with:")
    print(" --tool-call-parser qwen3_coder --enable-auto-tool-choice")


if __name__ == "__main__":
    main()
1020
qwen3_6_scripts/protocol.py
Normal file
File diff suppressed because it is too large
@@ -11,12 +11,15 @@ from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -417,9 +420,6 @@ class GatedDeltaNet(nn.Module):
|
||||
|
||||
else:
|
||||
# Decode: one token per sequence
|
||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
||||
_f.flush()
|
||||
num_seqs = hidden_states.shape[0]
|
||||
weight_2d = self.conv1d_weight.squeeze(1)
|
||||
|
||||
@@ -449,17 +449,47 @@ class GatedDeltaNet(nn.Module):
|
||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||
|
||||
core_out, last_state = _torch_recurrent_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=temporal_state,
|
||||
output_final_state=True,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
# Inlined decode recurrent step (seq_len=1).
|
||||
# Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+
|
||||
# contiguous+float32 copies, core_out allocation, and Python loop.
|
||||
# Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors.
|
||||
# temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place.
|
||||
orig_dtype = q.dtype
|
||||
_scale = self.head_k_dim ** -0.5
|
||||
|
||||
q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim)
|
||||
k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim)
|
||||
v_t = v.squeeze(1).float() # (B, H_v, v_dim)
|
||||
g_t = g.squeeze(1).float().exp_() # (B, H_v)
|
||||
bt = beta.squeeze(1).float() # (B, H_v)
|
||||
|
||||
# Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head
|
||||
temporal_state.mul_(g_t[:, :, None, None])
|
||||
|
||||
# Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim)
|
||||
ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
|
||||
BH = ts_flat.shape[0]
|
||||
|
||||
# kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim)
|
||||
kv_mem = torch.bmm(
|
||||
k_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||
).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim)
|
||||
|
||||
delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim)
|
||||
|
||||
# State update: temporal_state += outer(k_t, delta) fused, no intermediate
|
||||
ts_flat.baddbmm_(
|
||||
k_t.view(BH, self.head_k_dim, 1),
|
||||
delta.view(BH, 1, self.head_v_dim),
|
||||
)
|
||||
if last_state is not None:
|
||||
temporal_state.copy_(last_state)
|
||||
|
||||
# Output: core_out = q_t @ updated temporal_state
|
||||
core_out = torch.bmm(
|
||||
q_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||
).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)
|
||||
# core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already
|
||||
|
||||
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||
normed = self.norm(
|
||||
core_out.reshape(-1, self.head_v_dim),
|
||||
z.reshape(-1, self.head_v_dim))
|
||||
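The inlined decode step above (decay, read, delta, rank-1 update, output) can be followed on a tiny single-head example. This is a sketch in plain Python lists rather than the batched `bmm`/`baddbmm_` code from the diff; the function name is hypothetical, and `q` and `k` are assumed to be already L2-normalised and scaled, with `g` already exponentiated.

```python
def delta_rule_decode_step(state, q, k, v, g, beta):
    """One recurrent step for a single head.

    state: k_dim x v_dim list of lists, updated in place.
    q, k:  length k_dim; v: length v_dim; g, beta: scalars.
    """
    k_dim, v_dim = len(state), len(state[0])

    # 1. Decay the state by the per-head gate.
    for i in range(k_dim):
        for j in range(v_dim):
            state[i][j] *= g

    # 2. Read what the state currently predicts for this key: k @ state.
    kv_mem = [sum(k[i] * state[i][j] for i in range(k_dim)) for j in range(v_dim)]

    # 3. Error between the new value and the prediction, scaled by beta.
    delta = [(v[j] - kv_mem[j]) * beta for j in range(v_dim)]

    # 4. Rank-1 update: state += outer(k, delta).
    for i in range(k_dim):
        for j in range(v_dim):
            state[i][j] += k[i] * delta[j]

    # 5. Output: q @ updated state.
    return [sum(q[i] * state[i][j] for i in range(k_dim)) for j in range(v_dim)]


state = [[0.0, 0.0], [0.0, 0.0]]
out1 = delta_rule_decode_step(state, q=[1.0, 0.0], k=[1.0, 0.0], v=[3.0, 4.0], g=1.0, beta=1.0)
out2 = delta_rule_decode_step(state, q=[1.0, 0.0], k=[1.0, 0.0], v=[3.0, 4.0], g=1.0, beta=1.0)
print(out1, out2, state)
```

Writing the same key/value pair twice leaves the state unchanged on the second step, because the prediction already matches and `delta` is zero.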
@@ -495,24 +525,52 @@ class Qwen3_5FullAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.local_num_heads = self.num_heads // tp_size
|
||||
self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# When num_kv_heads < tp_size we cannot shard KV further (would give
|
||||
# fractional heads per rank). Use ReplicatedLinear so every rank holds
|
||||
# all KV heads; local_num_kv_heads equals the full count.
|
||||
# When num_kv_heads >= tp_size standard ColumnParallel sharding applies.
|
||||
if tp_size > self.num_kv_heads:
|
||||
# GQA-aware TP sharding: ixformer kernel only supports num_kv_heads=1
|
||||
# per rank. With num_kv_heads=2 < tp_size=4 we cannot shard KV
|
||||
# evenly, but we CAN assign each rank the ONE KV head that serves
|
||||
# its Q heads:
|
||||
# q_per_kv = num_heads // num_kv_heads (e.g. 16//2 = 8)
|
||||
# Rank r uses KV head r * local_num_heads // q_per_kv
|
||||
# e.g. ranks 0,1 → KV head 0; ranks 2,3 → KV head 1.
|
||||
# We replicate all KV heads to every rank and select in forward().
|
||||
self.proj_kv_heads = self.num_kv_heads # heads available from projection
|
||||
self.local_num_kv_heads = 1 # heads after rank-local selection
|
||||
self.q_per_kv_global = self.num_heads // self.num_kv_heads
|
||||
self.k_proj = ReplicatedLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config)
|
||||
self.v_proj = ReplicatedLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config)
|
||||
else:
|
||||
# Standard sharding: each rank gets num_kv_heads // tp_size heads.
|
||||
self.local_num_kv_heads = self.num_kv_heads // tp_size
|
||||
self.proj_kv_heads = self.local_num_kv_heads # already sharded
|
||||
self.q_per_kv_global = None
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
|
||||
self.local_q_dim = self.local_num_heads * self.head_dim
|
||||
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# q_proj includes gate: output = num_heads * head_dim * 2
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_heads * self.head_dim * 2,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.k_proj")
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||
bias=False, quant_config=quant_config,
|
||||
prefix=f"{prefix}.v_proj")
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim, self.hidden_size,
|
||||
bias=False, quant_config=quant_config,
|
||||
@@ -559,18 +617,34 @@ class Qwen3_5FullAttention(nn.Module):
|
||||
         q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
         gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)

-        k, _ = self.k_proj(hidden_states)    # (total, local_kv_dim)
+        k, _ = self.k_proj(hidden_states)    # (total, proj_kv_heads * head_dim)
         v, _ = self.v_proj(hidden_states)

-        # Per-head RMSNorm
+        # q_norm on local Q heads
         q = self.q_norm.forward_cuda(
             q.view(total_tokens, self.local_num_heads, self.head_dim)
             .contiguous()).view(total_tokens, -1)
+
+        # GQA-aware TP: select rank-local KV head BEFORE k_norm and rope so
+        # that ixformer kernels always see num_kv_heads=1 (same as 27B path).
+        # Doing k_norm/rope on 2 KV heads (proj_kv_heads=2) triggers ixformer
+        # paths that can produce NaN; restricting to 1 head avoids the issue.
+        if self.q_per_kv_global is not None:
+            tp_rank = get_tensor_model_parallel_rank()
+            kv_idx = (tp_rank * self.local_num_heads) // self.q_per_kv_global
+            k = (k.view(total_tokens, self.proj_kv_heads, self.head_dim)
+                 [:, kv_idx, :].contiguous())   # (T, head_dim) — 1 head
+            v = (v.view(total_tokens, self.proj_kv_heads, self.head_dim)
+                 [:, kv_idx, :].contiguous())   # (T, head_dim) — 1 head
+
+        # k_norm on the (now always 1) rank-local KV head
         k = self.k_norm.forward_cuda(
             k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
             .contiguous()).view(total_tokens, -1)
+
+        # rope: q=(T, local_num_heads*head_dim), k=(T, 1*head_dim) — mirrors 27B
         q, k = self.rotary_emb(positions, q, k)

         attn_out = self.attn(q, k, v, kv_cache, attn_metadata)

         # Multiply by sigmoid gate before output projection
||||
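The rank-to-KV-head selection used in this hunk can be checked in isolation. The following is a minimal sketch, not the patch's code: `kv_head_for_rank` is a hypothetical helper name, and the divisibility checks are an added assumption (the diff does not show any such validation). It only reproduces the index arithmetic `(tp_rank * local_num_heads) // q_per_kv` from the comments above.

```python
def kv_head_for_rank(tp_rank: int, num_heads: int, num_kv_heads: int,
                     tp_size: int) -> int:
    """Index of the single KV head that serves all Q heads held by tp_rank.

    Hypothetical helper for illustration; mirrors the kv_idx formula in the diff.
    """
    if tp_size <= num_kv_heads:
        raise ValueError("selection only applies when tp_size > num_kv_heads")
    # Assumption: Q heads split evenly over ranks, and ranks split evenly
    # over KV heads, so one rank never straddles two KV heads.
    if num_heads % tp_size or tp_size % num_kv_heads:
        raise ValueError("heads must divide evenly across ranks and KV groups")
    local_num_heads = num_heads // tp_size   # Q heads per rank
    q_per_kv = num_heads // num_kv_heads     # Q heads sharing one KV head
    return (tp_rank * local_num_heads) // q_per_kv


# Values from the comment in the diff: 16 Q heads, 2 KV heads, 4 ranks.
mapping = [kv_head_for_rank(r, num_heads=16, num_kv_heads=2, tp_size=4)
           for r in range(4)]
print(mapping)  # [0, 0, 1, 1]
```

With these numbers ranks 0 and 1 read KV head 0 and ranks 2 and 3 read KV head 1, which matches the example given in the `__init__` comment.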
@@ -609,10 +683,154 @@ class Qwen3_5MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MoE sparse block (Qwen3.5-MoE / Qwen3.6-35B-A3B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5MoeSparseBlock(nn.Module):
    """Replaces Qwen3_5MLP for qwen3_5_moe_text layers.

    FusedMoE is used ONLY for weight storage and loading (create_weights /
    weight_loader are pure PyTorch). Its forward kernel is bypassed because
    ixformer on BI-V100 lacks vllm_moe_topk_softmax / vllm_invoke_fused_moe_kernel.
    Routing and expert computation use a pure-PyTorch loop instead.

    Shared expert uses RowParallelLinear(reduce_results=False) so both paths
    produce partial (pre-all-reduce) outputs that are combined before a single
    all-reduce.
    """
|
||||
def __init__(
|
||||
self,
|
||||
text_cfg,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = text_cfg.hidden_size
|
||||
self.num_experts = text_cfg.num_experts
|
||||
self.top_k = text_cfg.num_experts_per_tok
|
||||
|
||||
# Router: replicated (small: num_experts outputs)
|
||||
self.gate = ReplicatedLinear(hidden_size, text_cfg.num_experts,
|
||||
bias=False, quant_config=quant_config)
|
||||
|
||||
# FusedMoE: only used for weight storage + weight_loader.
|
||||
# Forward is bypassed — see _pure_pytorch_experts().
|
||||
self.experts = FusedMoE(
|
||||
num_experts=text_cfg.num_experts,
|
||||
top_k=text_cfg.num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=text_cfg.moe_intermediate_size,
|
||||
reduce_results=False, # we do the all-reduce ourselves below
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Shared expert: defer all-reduce to combine with routed output first
|
||||
shared_size = text_cfg.shared_expert_intermediate_size
|
||||
self.shared_expert_gate_up = MergedColumnParallelLinear(
|
||||
hidden_size, [shared_size] * 2, bias=False,
|
||||
quant_config=quant_config)
|
||||
self.shared_expert_down = RowParallelLinear(
|
||||
shared_size, hidden_size, bias=False, reduce_results=False,
|
||||
quant_config=quant_config)
|
||||
self.act_fn = SiluAndMul()
|
||||
# Scalar sigmoid gate on shared expert output (same as Qwen2-MoE / Qwen3.5-MoE):
|
||||
# shared_out *= sigmoid(shared_expert_gate(hidden_states))
|
||||
# Without this, shared expert is always fully active → wrong logits.
|
||||
self.shared_expert_gate = ReplicatedLinear(
|
||||
hidden_size, 1, bias=False, quant_config=quant_config)
|
||||
|
||||
def _pure_pytorch_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Pure-PyTorch MoE (ixformer has no MoE kernels on BI-V100).
|
||||
|
||||
w13_weight: (num_experts, 2*inter_per_partition, hidden) [TP-sharded]
|
||||
w2_weight: (num_experts, hidden, inter_per_partition) [TP-sharded]
|
||||
Output is partial (pre-all-reduce), same contract as FusedMoE
|
||||
with reduce_results=False.
|
||||
"""
|
||||
# Routing: softmax → topk → renormalise
|
||||
routing_weights = torch.softmax(router_logits.float(), dim=-1)
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
routing_weights, self.top_k, dim=-1) # (T, top_k)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
w13 = self.experts.w13_weight # (E, 2*I, H)
|
||||
w2 = self.experts.w2_weight # (E, H, I)
|
||||
|
||||
T = hidden_states.shape[0]
|
||||
if T == 1:
|
||||
# Fast path: single token (decode).
|
||||
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||
# gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I)
|
||||
# down: 1 bmm (K,H,I) @ (K,I,1) → (K,H)
|
||||
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||
eids = topk_ids[0] # (K,)
|
||||
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||
w13_sel = w13[eids] # (K, 2*I, H)
|
||||
w2_sel = w2[eids] # (K, H, I)
|
||||
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
gate_up = F.linear(
|
||||
hidden_states,
|
||||
w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing
|
||||
) # (1, K*2*I)
|
||||
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||
act = F.silu(gate) * up # (K, I)
|
||||
|
||||
# bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H)
|
||||
expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H)
|
||||
|
||||
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||
hidden_states.dtype) # (1, H)
|
||||
else:
|
||||
# General path (prefill / multi-seq): loop over unique active experts.
|
||||
# At most T*top_k unique experts, always <= num_experts.
|
||||
out = torch.zeros_like(hidden_states)
|
||||
unique_eids = topk_ids.view(-1).unique().tolist()
|
||||
for eid in unique_eids:
|
||||
eid = int(eid)
|
||||
mask = (topk_ids == eid) # (T, top_k)
|
||||
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||
tokens = hidden_states[tok_ids] # (n, H)
|
||||
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
act = F.silu(gate) * up # (n, I)
|
||||
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||
|
||||
return out # partial, all-reduce done in forward()
|
||||
|
||||
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        routed_out = self._pure_pytorch_experts(hidden_states, router_logits)

        gate_up, _ = self.shared_expert_gate_up(hidden_states)
        shared_out = self.act_fn(gate_up)
        shared_out, _ = self.shared_expert_down(shared_out)
        # Scalar sigmoid gate (Qwen2-MoE / Qwen3.5-MoE style)
        gate_score, _ = self.shared_expert_gate(hidden_states)  # (T, 1)
        shared_out = shared_out * torch.sigmoid(gate_score)

        out = routed_out + shared_out
        if self.experts.tp_size > 1:
            out = tensor_model_parallel_all_reduce(out)
        return out
|
||||
|
||||
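The routing rule in `_pure_pytorch_experts` (softmax, then top-k, then renormalise the kept weights) can be illustrated without PyTorch. The sketch below is a toy under stated assumptions: `route_tokens` and `moe_forward` are hypothetical names, experts are plain Python callables rather than SwiGLU weight matrices, and tensor parallelism and the shared expert are left out.

```python
import math


def route_tokens(router_logits, top_k):
    """Per token: softmax over experts, keep top_k, renormalise to sum to 1."""
    routed = []
    for logits in router_logits:
        peak = max(logits)                       # subtract max for stability
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        top = sorted(range(len(probs)), key=lambda i: probs[i],
                     reverse=True)[:top_k]
        kept = sum(probs[i] for i in top)
        routed.append([(i, probs[i] / kept) for i in top])
    return routed


def moe_forward(tokens, router_logits, experts, top_k):
    """Weighted sum of the selected experts' outputs for each token."""
    outputs = []
    for x, choices in zip(tokens, route_tokens(router_logits, top_k)):
        y = [0.0] * len(x)
        for expert_id, weight in choices:
            for j, value in enumerate(experts[expert_id](x)):
                y[j] += weight * value
        outputs.append(y)
    return outputs


# Toy experts: expert s multiplies its input by s (1x, 2x, 3x).
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0)]
logits = [[math.log(1.0), math.log(2.0), math.log(5.0)]]  # probs 1/8, 2/8, 5/8
choices = route_tokens(logits, top_k=2)[0]
result = moe_forward([[1.0, 2.0]], logits, experts, top_k=2)
print(choices, result)
```

Here the two kept experts end up with weights 5/7 and 2/7, so the token is scaled by 5/7 * 3 + 2/7 * 2 = 19/7. The patch's general path computes the same weighted sum, but groups tokens by expert so that each active expert runs once per batch.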
# ---------------------------------------------------------------------------
|
||||
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Qwen3_5DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -623,6 +841,7 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = layer_type
|
||||
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
||||
eps=text_cfg.rms_norm_eps)
|
||||
@@ -640,12 +859,15 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
                 prefix=f"layers.{layer_idx}.self_attn",
             )

-        self.mlp = Qwen3_5MLP(
-            hidden_size=text_cfg.hidden_size,
-            intermediate_size=text_cfg.intermediate_size,
-            hidden_act=text_cfg.hidden_act,
-            quant_config=quant_config,
-        )
+        if getattr(text_cfg, 'model_type', '') == 'qwen3_5_moe_text':
+            self.mlp = Qwen3_5MoeSparseBlock(text_cfg, quant_config=quant_config)
+        else:
+            self.mlp = Qwen3_5MLP(
+                hidden_size=text_cfg.hidden_size,
+                intermediate_size=text_cfg.intermediate_size,
+                hidden_act=text_cfg.hidden_act,
+                quant_config=quant_config,
+            )

     def forward(
         self,
@@ -673,7 +895,9 @@ class Qwen3_5DecoderLayer(nn.Module):
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -860,8 +1084,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
         # With chunked prefill, intermediate chunks have seq_groups=None on all
         # ranks; _apply_logits_processors is guarded against this in
         # logits_processor.py (patched by patch_xformers_sdpa_seq.py).
-        return self.logits_processor(self.lm_head, hidden_states,
-                                     sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits

     def sample(
         self,
@@ -892,12 +1117,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
                     or name.startswith("model.mtp")):
                 continue

-            # Remap checkpoint prefix → module path
-            # Checkpoint: "model.language_model.{rest}"  →  our module: "model.{rest}"
-            # Checkpoint: "lm_head.weight"               →  our module: "lm_head.weight"
+            # Prefix remapping: checkpoint may wrap under language_model
             if name.startswith("model.language_model."):
                 name = "model." + name[len("model.language_model."):]
-            # lm_head is already at top level — no change needed

             # Skip positional embedding caches
             if "rotary_emb.inv_freq" in name:
@@ -931,3 +1153,118 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Qwen3.6-35B-A3B (Qwen3_5-MoE architecture)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
|
||||
"""Qwen3.6-35B-A3B: same hybrid-attention backbone as 27B, dense MLP
|
||||
replaced by Qwen3_5MoeSparseBlock (256 routed experts + shared expert).
|
||||
Only load_weights differs from the dense variant.
|
||||
"""
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Checkpoint key format for this model (transformers Qwen3_5MoeExperts):
|
||||
# mlp.experts.gate_up_proj shape (num_experts, 2*intermediate, hidden)
|
||||
# mlp.experts.down_proj shape (num_experts, hidden, intermediate)
|
||||
# mlp.gate.weight shape (num_experts, hidden) [router]
|
||||
# mlp.shared_expert.{gate,up,down}_proj.weight [shared MLP]
|
||||
# Our FusedMoE stores:
|
||||
# mlp.experts.w13_weight shape (num_experts, 2*intermediate//tp, hidden)
|
||||
# mlp.experts.w2_weight shape (num_experts, hidden, intermediate//tp)
|
||||
# Our shared expert stores:
|
||||
# mlp.shared_expert_gate_up.weight (merged gate+up)
|
||||
# mlp.shared_expert_down.weight
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, weight_name, shard_id)
|
||||
# shared expert
|
||||
("shared_expert_gate_up", "shared_expert.gate_proj", 0),
|
||||
("shared_expert_gate_up", "shared_expert.up_proj", 1),
|
||||
# linear_attention dense proj (same as 27B)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# Skip vision and MTP branches
|
||||
if (name.startswith("model.visual")
|
||||
or name.startswith("mtp.")
|
||||
or name.startswith("model.mtp")):
|
||||
continue
|
||||
|
||||
            # Prefix remapping: VL checkpoints (Qwen3_5MoeForConditionalGeneration)
            # wrap the text weights under "language_model":
            #   model.language_model.{rest} -> model.{rest}
            # Names without that prefix (e.g. lm_head.weight) are left unchanged.
            if name.startswith("model.language_model."):
                name = "model." + name[len("model.language_model."):]
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if ".linear_attn.conv1d.weight" in name:
|
||||
name = name.replace(".linear_attn.conv1d.weight",
|
||||
".linear_attn.conv1d_weight")
|
||||
|
||||
# --- Fused routed-expert weights (all experts in one tensor) ---
|
||||
|
||||
if "mlp.experts.gate_up_proj" in name:
|
||||
# loaded_weight: (num_experts, 2*intermediate, hidden)
|
||||
w13_name = name.replace("mlp.experts.gate_up_proj",
|
||||
"mlp.experts.w13_weight")
|
||||
if w13_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[w13_name]
|
||||
n_exp = loaded_weight.shape[0]
|
||||
inter = loaded_weight.shape[1] // 2
|
||||
gate_w = loaded_weight[:, :inter, :].contiguous()
|
||||
up_w = loaded_weight[:, inter:, :].contiguous()
|
||||
for eid in range(n_exp):
|
||||
param.weight_loader(param, gate_w[eid], "w1_weight", "w1", eid)
|
||||
param.weight_loader(param, up_w[eid], "w3_weight", "w3", eid)
|
||||
continue
|
||||
|
||||
if "mlp.experts.down_proj" in name:
|
||||
# loaded_weight: (num_experts, hidden, intermediate)
|
||||
w2_name = name.replace("mlp.experts.down_proj",
|
||||
"mlp.experts.w2_weight")
|
||||
if w2_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[w2_name]
|
||||
n_exp = loaded_weight.shape[0]
|
||||
for eid in range(n_exp):
|
||||
param.weight_loader(param, loaded_weight[eid], "w2_weight", "w2", eid)
|
||||
continue
|
||||
|
||||
# --- Shared expert down_proj rename ---
|
||||
if "mlp.shared_expert.down_proj" in name:
|
||||
name = name.replace("mlp.shared_expert.down_proj",
|
||||
"mlp.shared_expert_down")
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
|
||||
# --- Stacked / standard weights ---
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if name not in params_dict:
|
||||
break
|
||||
param = params_dict[name]
|
||||
param.weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
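The name handling in `Qwen3_5MoeForCausalLM.load_weights` depends on the order of `stacked_params_mapping`: the `shared_expert.*` entries must be tried before the generic `gate_proj` / `up_proj` entries, because the generic substrings also occur in shared-expert names. A minimal sketch of just the renaming step, assuming a hypothetical `remap` helper and leaving out the fused-expert and `conv1d` branches:

```python
# Same tuples as stacked_params_mapping in the diff:
# (param_name, weight_name, shard_id)
STACKED = [
    ("shared_expert_gate_up", "shared_expert.gate_proj", 0),
    ("shared_expert_gate_up", "shared_expert.up_proj", 1),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name):
    """Return (parameter name, shard_id); shard_id is None for plain weights."""
    prefix = "model.language_model."
    if name.startswith(prefix):
        name = "model." + name[len(prefix):]
    for param_name, weight_name, shard_id in STACKED:
        if weight_name in name:
            # First match wins, which is why the shared-expert rows come first.
            return name.replace(weight_name, param_name), shard_id
    return name, None


print(remap("model.language_model.layers.3.mlp.shared_expert.gate_proj.weight"))
print(remap("model.layers.0.mlp.up_proj.weight"))
print(remap("lm_head.weight"))
```

In the patch, names containing `mlp.experts.gate_up_proj` never reach this loop because the fused-expert branch handles them first and continues; the sketch assumes the same.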
3
qwen3_6_scripts/qwen3_5_moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig
|
||||
|
||||
__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
|
||||
198
qwen3_6_scripts/qwen3_5_moe/configuration_qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Adapted from transformers 5.2.0 for compatibility with transformers 4.55.3 + torch 2.1.0
|
||||
# Source: transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
|
||||
# Stubs layer_type_validation and RopeParameters which do not exist in 4.55.3
|
||||
# Removes ignore_keys_at_rope_validation / base_model_tp_plan / base_model_pp_plan
|
||||
# which are 5.x-only and irrelevant for vLLM inference.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig as PreTrainedConfig
|
||||
|
||||
# --- Local stubs for APIs not present in transformers 4.55.3 ---
|
||||
def layer_type_validation(layer_types, num_hidden_layers=None, attention=True):
|
||||
allowed = {"full_attention", "linear_attention"}
|
||||
if not all(lt in allowed for lt in layer_types):
|
||||
raise ValueError(f"layer_types entries must be in {allowed}, got {layer_types}")
|
||||
if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
|
||||
raise ValueError(
|
||||
f"num_hidden_layers ({num_hidden_layers}) != len(layer_types) ({len(layer_types)})"
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
class RopeParameters(TypedDict, total=False):
|
||||
rope_theta: float
|
||||
rope_type: str
|
||||
partial_rotary_factor: float
|
||||
factor: float
|
||||
except Exception:
|
||||
RopeParameters = dict
|
||||
|
||||
# --- End stubs ---
|
||||
|
||||
|
||||
class Qwen3_5MoeTextConfig(PreTrainedConfig):
|
||||
r"""
|
||||
Configuration for the text backbone of Qwen3.5-MoE / Qwen3.6-35B-A3B models.
|
||||
model_type is "qwen3_5_moe_text" (used internally by the nested config).
|
||||
"""
|
||||
|
||||
model_type = "qwen3_5_moe_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=248320,
|
||||
hidden_size=2048,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=2,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_parameters=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
head_dim=256,
|
||||
linear_conv_kernel_dim=4,
|
||||
linear_key_head_dim=128,
|
||||
linear_value_head_dim=128,
|
||||
linear_num_key_heads=16,
|
||||
linear_num_value_heads=32,
|
||||
moe_intermediate_size=512,
|
||||
shared_expert_intermediate_size=512,
|
||||
num_experts_per_tok=8,
|
||||
num_experts=256,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
layer_types=None,
|
||||
pad_token_id=None,
|
||||
bos_token_id=None,
|
||||
eos_token_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.head_dim = head_dim
|
||||
self.rope_parameters = rope_parameters
|
||||
kwargs.setdefault("partial_rotary_factor", 0.25)
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
interval_pattern = kwargs.get("full_attention_interval", 4)
|
||||
self.layer_types = [
|
||||
"linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types, self.num_hidden_layers)
|
||||
|
||||
self.linear_conv_kernel_dim = linear_conv_kernel_dim
|
||||
self.linear_key_head_dim = linear_key_head_dim
|
||||
self.linear_value_head_dim = linear_value_head_dim
|
||||
self.linear_num_key_heads = linear_num_key_heads
|
||||
self.linear_num_value_heads = linear_num_value_heads
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.shared_expert_intermediate_size = shared_expert_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Qwen3_5MoeVisionConfig(PreTrainedConfig):
|
||||
model_type = "qwen3_5_moe"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=27,
|
||||
hidden_size=1152,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
intermediate_size=4304,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=16,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
out_hidden_size=3584,
|
||||
num_position_embeddings=2304,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.num_position_embeddings = num_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class Qwen3_5MoeConfig(PreTrainedConfig):
|
||||
r"""
|
||||
Top-level configuration for Qwen3.5-MoE / Qwen3.6-35B-A3B.
|
||||
model_type = "qwen3_5_moe" matches the model card / config.json.
|
||||
Wraps Qwen3_5MoeTextConfig (and optionally Qwen3_5MoeVisionConfig).
|
||||
For vLLM text-only inference only text_config is consumed.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_5_moe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_token_id=248056,
|
||||
video_token_id=248057,
|
||||
vision_start_token_id=248053,
|
||||
vision_end_token_id=248054,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = Qwen3_5MoeTextConfig(**text_config)
|
||||
elif text_config is None:
|
||||
self.text_config = Qwen3_5MoeTextConfig()
|
||||
else:
|
||||
self.text_config = text_config
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = Qwen3_5MoeVisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = Qwen3_5MoeVisionConfig()
|
||||
else:
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
|
||||
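When `layer_types` is not given, `Qwen3_5MoeTextConfig` derives it from `full_attention_interval` (default 4 in the code above): every fourth layer is full attention and the rest are linear attention. A small standalone sketch of that rule, with `default_layer_types` as a hypothetical helper name that is not part of the config module:

```python
def default_layer_types(num_hidden_layers, full_attention_interval=4):
    """Layer i is full attention when (i + 1) is a multiple of the interval."""
    return [
        "linear_attention" if (i + 1) % full_attention_interval
        else "full_attention"
        for i in range(num_hidden_layers)
    ]


# Config default above: num_hidden_layers=40.
types = default_layer_types(40)
print(types[:4])
print(types.count("full_attention"), types.count("linear_attention"))
```

For the 40-layer default this gives 10 full-attention layers (indices 3, 7, ..., 39) and 30 linear-attention layers.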
509
qwen3_6_scripts/qwen3coder_tool_parser.py
Normal file
@@ -0,0 +1,509 @@
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("qwen3_coder")
|
||||
class Qwen3CoderToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for Qwen3 models using XML-style tool call format:
|
||||
<tool_call><function=name><parameter=key>
|
||||
value
|
||||
</parameter></function></tool_call>
|
||||
|
||||
Port of vllm-original qwen3coder_tool_parser.py to vllm 0.6.3 API.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
# Base class uses int; we override with string IDs
|
||||
self.current_tool_id: Optional[str] = None # type: ignore[assignment]
|
||||
self.streamed_args_for_tool: List[str] = []
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_prefix: str = "<function="
|
||||
self.function_end_token: str = "</function>"
|
||||
self.parameter_prefix: str = "<parameter="
|
||||
self.parameter_end_token: str = "</parameter>"
|
||||
self.is_tool_call_started: bool = False
|
||||
|
||||
self._reset_streaming_state()
|
||||
|
||||
self.tool_call_complete_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
|
||||
self.tool_call_function_regex = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||
self.tool_call_parameter_regex = re.compile(
|
||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||
re.DOTALL)
|
||||
|
||||
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor.")
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Qwen3 XML Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
        logger.debug("vLLM successfully imported tool parser %s",
                     self.__class__.__name__)
|
||||
|
||||
def _generate_tool_call_id(self) -> str:
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
self.current_tool_index = 0
|
||||
self.is_tool_call_started = False
|
||||
self.header_sent = False
|
||||
self.current_tool_id = None
|
||||
self.current_function_name: Optional[str] = None
|
||||
self.current_param_name: Optional[str] = None
|
||||
self.current_param_value: str = ""
|
||||
self.param_count = 0
|
||||
self.in_param = False
|
||||
self.in_function = False
|
||||
self.accumulated_text: str = ""
|
||||
self.json_started = False
|
||||
self.json_closed = False
|
||||
self.accumulated_params: Dict[str, Any] = {}
|
||||
self.streaming_request: Optional[ChatCompletionRequest] = None
|
||||
|
||||
def _get_arguments_config(
|
||||
self, func_name: str,
|
||||
tools: Optional[List[ChatCompletionToolsParam]]) -> Dict:
|
||||
if tools is None:
|
||||
return {}
|
||||
for config in tools:
|
||||
if not hasattr(config, "type") or not (
|
||||
hasattr(config, "function")
|
||||
and hasattr(config.function, "name")):
|
||||
continue
|
||||
if config.type == "function" and config.function.name == func_name:
|
||||
if not hasattr(config.function, "parameters"):
|
||||
return {}
|
||||
params = config.function.parameters
|
||||
if isinstance(params, dict) and "properties" in params:
|
||||
return params["properties"]
|
||||
elif isinstance(params, dict):
|
||||
return params
|
||||
else:
|
||||
return {}
|
||||
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
|
||||
return {}
|
||||
|
||||
def _convert_param_value(self, param_value: str, param_name: str,
|
||||
param_config: Dict, func_name: str) -> Any:
|
||||
if param_value.lower() == "null":
|
||||
return None
|
||||
|
||||
if param_name not in param_config:
|
||||
if param_config != {}:
|
||||
logger.debug(
|
||||
"Parsed parameter '%s' is not defined in tool '%s', "
|
||||
"returning string value.", param_name, func_name)
|
||||
return param_value
|
||||
|
||||
if (isinstance(param_config[param_name], dict)
|
||||
and "type" in param_config[param_name]):
|
||||
param_type = str(
|
||||
param_config[param_name]["type"]).strip().lower()
|
||||
else:
|
||||
param_type = "string"
|
||||
|
||||
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
|
||||
return param_value
|
||||
elif (param_type.startswith("int") or param_type.startswith("uint")
|
||||
or param_type.startswith("long")
|
||||
or param_type.startswith("short")
|
||||
or param_type.startswith("unsigned")):
|
||||
try:
|
||||
return int(param_value)
|
||||
except (ValueError, TypeError):
|
||||
return param_value
|
||||
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                v = float(param_value)
                return int(v) if v - int(v) == 0 else v
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int(v) fails for infinite values such as "inf"
                return param_value
elif param_type in ["boolean", "bool", "binary"]:
|
||||
lower = param_value.lower()
|
||||
if lower not in ["true", "false"]:
|
||||
logger.debug(
|
||||
"Parameter '%s' value '%s' is not boolean in tool '%s'.",
|
||||
param_name, param_value, func_name)
|
||||
return lower == "true"
|
||||
else:
|
||||
if (param_type in ["object", "array", "arr"]
|
||||
or param_type.startswith("dict")
|
||||
or param_type.startswith("list")):
|
||||
try:
|
||||
return json.loads(param_value)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(param_value)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
return param_value
|
||||
|
||||
    def _parse_xml_function_call(
            self, function_call_str: str,
            tools: Optional[List[ChatCompletionToolsParam]]) -> ToolCall:
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = self._get_arguments_config(function_name, tools)
        parameters = function_call_str[end_index + 1:]
        param_dict: Dict[str, Any] = {}
        for match_text in self.tool_call_parameter_regex.findall(parameters):
            idx = match_text.index(">")
            param_name = match_text[:idx]
            param_value = str(match_text[idx + 1:])
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]
            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name)
        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(param_dict, ensure_ascii=False)))

    def _get_function_calls(self, model_output: str) -> List[str]:
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]
        if not raw_tool_calls:
            raw_tool_calls = [model_output]
        raw_function_calls: List[tuple] = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(
                self.tool_call_function_regex.findall(tool_call))
        return [match[0] if match[0] else match[1]
                for match in raw_function_calls]

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        if self.tool_call_prefix not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            function_calls = self._get_function_calls(model_output)
            if not function_calls:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls = [
                self._parse_xml_function_call(fc, request.tools)
                for fc in function_calls
            ]

            self.prev_tool_call_arr.clear()
            for tc in tool_calls:
                self.prev_tool_call_arr.append({
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                })

            content_index = model_output.find(self.tool_call_start_token)
            idx = model_output.find(self.tool_call_prefix)
            content_index = content_index if content_index >= 0 else idx
            content = model_output[:content_index]

            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            logger.exception("Error extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        if not previous_text:
            self._reset_streaming_state()
            self.streaming_request = request

        if not delta_text:
            if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text))
                if complete_calls > 0 and self.prev_tool_call_arr:
                    open_calls = (
                        current_text.count(self.tool_call_start_token) -
                        current_text.count(self.tool_call_end_token))
                    if open_calls == 0:
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    return DeltaMessage(content="")
            return None

        self.accumulated_text = current_text

        if self.json_closed and not self.in_function:
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False
                self.accumulated_params = {}
                tool_starts = current_text.count(self.tool_call_start_token)
                if self.current_tool_index >= tool_starts:
                    self.is_tool_call_started = False
                return None

        if not self.is_tool_call_started:
            if (self.tool_call_start_token_id in delta_token_ids
                    or self.tool_call_start_token in delta_text):
                self.is_tool_call_started = True
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[:delta_text.index(
                        self.tool_call_start_token)]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            else:
                if (current_text.rstrip().endswith(self.tool_call_end_token)
                        and delta_text.strip() == ""):
                    return None
                return DeltaMessage(content=delta_text)

        tool_starts_count = current_text.count(self.tool_call_start_token)
        if self.current_tool_index >= tool_starts_count:
            return None

        # Locate the current tool call's text slice
        tool_start_positions: List[int] = []
        search = 0
        while True:
            search = current_text.find(self.tool_call_start_token, search)
            if search == -1:
                break
            tool_start_positions.append(search)
            search += len(self.tool_call_start_token)

        if self.current_tool_index >= len(tool_start_positions):
            return None

        tool_start_idx = tool_start_positions[self.current_tool_index]
        tool_end_idx = current_text.find(self.tool_call_end_token,
                                         tool_start_idx)
        if tool_end_idx == -1:
            tool_text = current_text[tool_start_idx:]
        else:
            tool_text = current_text[tool_start_idx:tool_end_idx +
                                     len(self.tool_call_end_token)]

        if not self.header_sent:
            if self.tool_call_prefix in tool_text:
                func_start = (tool_text.find(self.tool_call_prefix) +
                              len(self.tool_call_prefix))
                func_end = tool_text.find(">", func_start)
                if func_end != -1:
                    self.current_function_name = tool_text[func_start:func_end]
                    self.current_tool_id = self._generate_tool_call_id()
                    self.header_sent = True
                    self.in_function = True
                    self.prev_tool_call_arr.append({
                        "name": self.current_function_name,
                        "arguments": "{}",
                    })
                    self.streamed_args_for_tool.append("")
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name,
                                arguments=""),
                            type="function",
                        )
                    ])
            return None

        if self.in_function:
            if not self.json_started:
                self.json_started = True
                self.streamed_args_for_tool[self.current_tool_index] += "{"
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ])

            # Collect all complete parameters in one pass (speculative-decode safe)
            param_starts: List[int] = []
            search = 0
            while True:
                search = tool_text.find(self.parameter_prefix, search)
                if search == -1:
                    break
                param_starts.append(search)
                search += len(self.parameter_prefix)

            json_fragments: List[str] = []
            while not self.in_param and self.param_count < len(param_starts):
                param_idx = param_starts[self.param_count]
                param_start = param_idx + len(self.parameter_prefix)
                remaining = tool_text[param_start:]

                if ">" not in remaining:
                    break

                name_end = remaining.find(">")
                current_param_name = remaining[:name_end]
                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    next_param = value_text.find(self.parameter_prefix)
                    func_end = value_text.find(self.function_end_token)
                    if next_param != -1 and (func_end == -1
                                             or next_param < func_end):
                        param_end_idx = next_param
                    elif func_end != -1:
                        param_end_idx = func_end
                    else:
                        tool_end_in_value = value_text.find(
                            self.tool_call_end_token)
                        if tool_end_in_value != -1:
                            param_end_idx = tool_end_in_value
                        else:
                            break

                if param_end_idx == -1:
                    break

                param_value = value_text[:param_end_idx]
                if param_value.endswith("\n"):
                    param_value = param_value[:-1]

                self.accumulated_params[current_param_name] = param_value
                param_config = self._get_arguments_config(
                    self.current_function_name or "",
                    self.streaming_request.tools
                    if self.streaming_request else None)
                converted = self._convert_param_value(
                    param_value, current_param_name, param_config,
                    self.current_function_name or "")
                serialized = json.dumps(converted, ensure_ascii=False)

                sep = "" if self.param_count == 0 else ", "
                json_fragments.append(
                    f'{sep}"{current_param_name}": {serialized}')
                self.param_count += 1

            if json_fragments:
                combined = "".join(json_fragments)
                if self.current_tool_index < len(self.streamed_args_for_tool):
                    self.streamed_args_for_tool[
                        self.current_tool_index] += combined
                else:
                    logger.warning(
                        "streamed_args_for_tool out of sync: index=%d len=%d",
                        self.current_tool_index,
                        len(self.streamed_args_for_tool))
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments=combined),
                    )
                ])

            # Emit closing brace when </function> is seen (after params are done)
            if not self.json_closed and self.function_end_token in tool_text:
                self.json_closed = True
                func_start = (tool_text.find(self.tool_call_prefix) +
                              len(self.tool_call_prefix))
                func_content_end = tool_text.find(self.function_end_token,
                                                  func_start)
                if func_content_end != -1:
                    try:
                        parsed_tool = self._parse_xml_function_call(
                            tool_text[func_start:func_content_end],
                            self.streaming_request.tools
                            if self.streaming_request else None)
                        if self.current_tool_index < len(
                                self.prev_tool_call_arr):
                            self.prev_tool_call_arr[
                                self.current_tool_index]["arguments"] = (
                                    parsed_tool.function.arguments)
                    except Exception:
                        logger.debug("Failed to parse tool call during "
                                     "streaming: %s",
                                     tool_text,
                                     exc_info=True)

                if self.current_tool_index < len(self.streamed_args_for_tool):
                    self.streamed_args_for_tool[
                        self.current_tool_index] += "}"
                else:
                    logger.warning(
                        "streamed_args_for_tool out of sync: index=%d len=%d",
                        self.current_tool_index,
                        len(self.streamed_args_for_tool))

                result = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="}"),
                    )
                ])
                self.in_function = False
                self.accumulated_params = {}
                return result

        return None
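The non-streaming path of the tool parser above can be illustrated with a small standalone sketch. It is not the parser itself: the delimiter strings (`<function=`, `<parameter=`, `</parameter>`, `</function>`) are assumed from the Qwen3-Coder XML style, since the parser's real constants are defined outside this hunk, and `convert` and `parse_call` are hypothetical helper names that cover only a few of the type branches.

```python
import ast
import json
import re

# Assumed delimiters; the real constants live elsewhere in the parser file.
FUNCTION_RE = re.compile(r"<function=(.*?)</function>", re.DOTALL)
PARAMETER_RE = re.compile(r"<parameter=(.*?)</parameter>", re.DOTALL)


def convert(value: str, param_type: str):
    """Coerce a raw string using a JSON-schema-like type name (simplified)."""
    if param_type.startswith("int"):
        try:
            return int(value)
        except ValueError:
            return value
    if param_type in ("boolean", "bool"):
        return value.lower() == "true"
    if param_type in ("object", "array"):
        try:
            return json.loads(value)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            pass
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return value
    return value


def parse_call(text: str, schema: dict) -> dict:
    """Parse the first function block into {"name", "arguments"}."""
    match = FUNCTION_RE.search(text)
    if match is None:
        raise ValueError("no function block found")
    name, _, params = match.group(1).partition(">")
    args = {}
    for chunk in PARAMETER_RE.findall(params):
        key, _, raw = chunk.partition(">")
        # Drop the single newline padding each value, as the parser does.
        raw = raw[1:] if raw.startswith("\n") else raw
        raw = raw[:-1] if raw.endswith("\n") else raw
        args[key] = convert(raw, schema.get(key, "string"))
    return {"name": name, "arguments": json.dumps(args, ensure_ascii=False)}


sample = ("<tool_call>\n<function=get_weather>\n"
          "<parameter=city>\nParis\n</parameter>\n"
          "<parameter=days>\n3\n</parameter>\n"
          "</function>\n</tool_call>")
result = parse_call(sample, {"city": "string", "days": "integer"})
```

Here `schema` is a flattened name-to-type mapping used only for illustration; the real parser looks the types up in the request's tool definitions through `_get_arguments_config`.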
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
@@ -0,0 +1,16 @@
"""
Reasoning parser module for vLLM 0.6.3 (BI-V100 / Qwen3.6-27B adaptation).

Usage: --reasoning-parser qwen3
"""

from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager

__all__ = ["ReasoningParser", "ReasoningParserManager"]

# Lazy-register Qwen3 parser; imported on first get_reasoning_parser("qwen3").
ReasoningParserManager.register_lazy(
    "qwen3",
    "vllm.reasoning.qwen3_reasoning_parser",
    "Qwen3ReasoningParser",
)
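The lazy registration used in this `__init__.py` defers the import of the parser module until the parser is first requested. A minimal sketch of that pattern follows; `LazyRegistry` is a hypothetical stand-in and not the `ReasoningParserManager` class, and it resolves a standard-library class so that it runs without vLLM installed.

```python
import importlib


class LazyRegistry:
    """Hypothetical stand-in illustrating deferred-import registration."""

    def __init__(self):
        self._loaded = {}  # name -> resolved object
        self._lazy = {}    # name -> (module_path, attribute_name)

    def register_lazy(self, name, module_path, attr):
        # Only the location is stored; nothing is imported yet.
        self._lazy[name] = (module_path, attr)

    def get(self, name):
        if name in self._loaded:
            return self._loaded[name]
        if name not in self._lazy:
            raise KeyError(
                f"'{name}' not found. Available: {sorted(self._lazy)}")
        module_path, attr = self._lazy[name]
        # The import happens here, on first lookup.
        obj = getattr(importlib.import_module(module_path), attr)
        self._loaded[name] = obj  # cache so later lookups skip the import
        return obj


registry = LazyRegistry()
registry.register_lazy("ordered", "collections", "OrderedDict")
cls = registry.get("ordered")
```

The design choice is that registering a name costs nothing at start-up, and a broken or missing parser module only raises when that specific parser is selected.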
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
@@ -0,0 +1,243 @@
"""
Abstract reasoning parser base classes for vLLM 0.6.3.
Adapted from vllm-original/vllm/reasoning/abs_reasoning_parsers.py:
  - Removed vllm.entrypoints.mcp, vllm.utils.collection_utils, import_utils
  - DeltaMessage from vllm 0.6.3 protocol path
  - TokenizerLike -> AnyTokenizer
  - ReasoningParserManager: simplified eager + lazy registration
"""

import importlib
from abc import abstractmethod
from collections.abc import Iterable, Sequence
from functools import cached_property
from typing import Any, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.entrypoints.openai.protocol import DeltaMessage
    from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
    DeltaMessage = Any
    AnyTokenizer = Any


class ReasoningParser:
    """Abstract base for all reasoning parsers."""

    def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> dict:
        return self.model_tokenizer.get_vocab()

    @abstractmethod
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return True once the reasoning block has closed in input_ids."""

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        return self.is_reasoning_end(input_ids)

    @abstractmethod
    def extract_content_ids(self, input_ids: list) -> list:
        """Return token ids that belong to the content (post-reasoning) part."""

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        return 0

    @abstractmethod
    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        """
        Split a complete model output into (reasoning_text, content_text).
        Either part may be None.
        """

    @abstractmethod
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Optional["DeltaMessage"]:
        """
        Extract reasoning from a streaming delta.
        Returns a DeltaMessage with reasoning_content and/or content set,
        or None if this delta should be suppressed (control token).
        """


class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base for parsers that use <start_token>...</end_token> delimiters.
    Subclasses define start_token / end_token properties.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        raise NotImplementedError

    def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError("Tokenizer must be passed to ReasoningParser.")
        if not self.start_token or not self.end_token:
            raise ValueError("start_token and end_token must be defined.")

        self.start_token_id: Optional[int] = self.vocab.get(self.start_token)
        self.end_token_id: Optional[int] = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__}: could not find think tokens "
                f"'{self.start_token}'/'{self.end_token}' in tokenizer vocab."
            )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        for token_id in reversed(input_ids):
            if token_id == self.start_token_id:
                return False
            if token_id == self.end_token_id:
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        return self.end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list) -> list:
        if self.end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.end_token_id) + 1:]

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        count = 0
        depth = 0
        for tid in token_ids:
            if tid == self.start_token_id:
                depth += 1
            elif tid == self.end_token_id:
                if depth > 0:
                    depth -= 1
            elif depth > 0:
                count += 1
        return count

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        # Strip <think> if the model generated it (old-style template).
        parts = model_output.partition(self.start_token)
        model_output = parts[2] if parts[1] else parts[0]

        if self.end_token not in model_output:
            return model_output, None
        reasoning, _, content = model_output.partition(self.end_token)
        return reasoning, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Optional["DeltaMessage"]:
        from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage

        # Suppress lone control tokens.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self.start_token_id, self.end_token_id
        ):
            return None

        start_in_prev = self.start_token_id in previous_token_ids
        start_in_delta = self.start_token_id in delta_token_ids
        end_in_prev = self.end_token_id in previous_token_ids
        end_in_delta = self.end_token_id in delta_token_ids

        if start_in_prev:
            if end_in_delta:
                end_idx = delta_text.find(self.end_token)
                reasoning = delta_text[:end_idx] if end_idx >= 0 else ""
                content = delta_text[end_idx + len(self.end_token):] if end_idx >= 0 else None
                return _DeltaMessage(
                    reasoning_content=reasoning or None,
                    content=content or None,
                )
            elif end_in_prev:
                return _DeltaMessage(content=delta_text)
            else:
                return _DeltaMessage(reasoning_content=delta_text)

        elif start_in_delta:
            if end_in_delta:
                start_idx = delta_text.find(self.start_token)
                end_idx = delta_text.find(self.end_token)
                reasoning = delta_text[start_idx + len(self.start_token):end_idx]
                content = delta_text[end_idx + len(self.end_token):]
                return _DeltaMessage(
                    reasoning_content=reasoning or None,
                    content=content or None,
                )
            else:
                return _DeltaMessage(reasoning_content=delta_text)

        else:
            return _DeltaMessage(content=delta_text)


class ReasoningParserManager:
    """
    Registry for ReasoningParser implementations.
    Supports eager and lazy registration.
    """

    _parsers: dict = {}  # name -> class (eager)
    _lazy: dict = {}  # name -> (module_path, class_name)

    @classmethod
    def register_module(cls, name: str, parser_cls: type) -> None:
        """Eagerly register a ReasoningParser class."""
        if not issubclass(parser_cls, ReasoningParser):
            raise TypeError(f"{parser_cls} is not a ReasoningParser subclass.")
        cls._parsers[name] = parser_cls

    @classmethod
    def register_lazy(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a parser for deferred import."""
        cls._lazy[name] = (module_path, class_name)

    @classmethod
    def get_reasoning_parser(cls, name: str) -> type:
        if name in cls._parsers:
            return cls._parsers[name]
        if name in cls._lazy:
            module_path, class_name = cls._lazy[name]
            mod = importlib.import_module(module_path)
            parser_cls = getattr(mod, class_name)
            cls._parsers[name] = parser_cls
            return parser_cls
        registered = sorted(set(cls._parsers) | set(cls._lazy))
        raise KeyError(
            f"Reasoning parser '{name}' not found. "
            f"Available: {registered}"
        )

    @classmethod
    def list_registered(cls) -> list:
        return sorted(set(cls._parsers) | set(cls._lazy))
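The token-level helpers of `BaseThinkingReasoningParser` can be exercised without a tokenizer. The sketch below restates `is_reasoning_end` and `count_reasoning_tokens` as free functions; the token ids `THINK` and `END_THINK` are made-up values for illustration, not ids from any real vocabulary.

```python
from typing import Sequence

THINK, END_THINK = 1000, 1001  # hypothetical ids for <think> and </think>


def is_reasoning_end(ids: Sequence[int]) -> bool:
    """Scan from the end: whichever delimiter appears last decides."""
    for tid in reversed(ids):
        if tid == THINK:
            return False
        if tid == END_THINK:
            return True
    return False


def count_reasoning_tokens(ids: Sequence[int]) -> int:
    """Count tokens strictly inside a start/end pair (delimiters excluded)."""
    count = 0
    depth = 0
    for tid in ids:
        if tid == THINK:
            depth += 1
        elif tid == END_THINK:
            if depth > 0:
                depth -= 1
        elif depth > 0:
            count += 1
    return count


generated = [THINK, 5, 6, 7, END_THINK, 8, 9]
inside = count_reasoning_tokens(generated)
closed = is_reasoning_end(generated)
```

Note that a sequence with no start token, such as `[5, 6, END_THINK, 8]`, counts zero reasoning tokens under this scheme, because depth never rises above zero.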
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
@@ -0,0 +1,108 @@
"""
Reasoning parser for Qwen3 / Qwen3.5 / Qwen3.6 model family.
Adapted from vllm-original/vllm/reasoning/qwen3_reasoning_parser.py.

The model uses <think>...</think> to wrap chain-of-thought output.
For Qwen3.5+ the chat template injects <think> into the prompt, so only
</think> appears in the generated tokens; older templates generate <think>
themselves. Both styles are handled.
"""

from typing import Optional, Sequence, Any

from vllm.reasoning.abs_reasoning_parsers import (
    BaseThinkingReasoningParser,
    ReasoningParserManager,
)


class Qwen3ReasoningParser(BaseThinkingReasoningParser):

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        self.thinking_enabled = chat_kwargs.get("enable_thinking", True)

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        # Strip <think> if the model generated it (old template / edge case).
        parts = model_output.partition(self.start_token)
        model_output = parts[2] if parts[1] else parts[0]

        if self.end_token not in model_output:
            if not self.thinking_enabled:
                return None, model_output
            # Thinking enabled but output truncated before </think>.
            return model_output, None

        reasoning, _, content = model_output.partition(self.end_token)
        return reasoning, content or None

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        token_ids = list(token_ids)
        if self.start_token_id in token_ids:
            # Old-style template: model generates <think> itself.
            # Use depth-counting from the base class.
            return super().count_reasoning_tokens(token_ids)
        elif self.end_token_id in token_ids:
            # New-style template (Qwen3.5+): <think> is injected into the
            # prompt, so output starts already inside the thinking block.
            # Every token before </think> is a reasoning token.
            return token_ids.index(self.end_token_id)
        else:
            # No </think> in output: either truncated (all reasoning)
            # or thinking disabled (none).
            return len(token_ids) if self.thinking_enabled else 0

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ):
        from vllm.entrypoints.openai.protocol import DeltaMessage

        if not self.thinking_enabled:
            return DeltaMessage(content=delta_text) if delta_text else None

        # Strip <think> from delta if the model generates it itself.
        if self.start_token_id in delta_token_ids:
            start_idx = delta_text.find(self.start_token)
            if start_idx >= 0:
                delta_text = delta_text[start_idx + len(self.start_token):]

        if self.end_token_id in delta_token_ids:
            end_idx = delta_text.find(self.end_token)
            if end_idx >= 0:
                reasoning = delta_text[:end_idx]
                content = delta_text[end_idx + len(self.end_token):]
                if not reasoning and not content:
                    return None
                return DeltaMessage(
                    reasoning_content=reasoning or None,
                    content=content or None,
                )
            return None

        if not delta_text:
            return None
        elif self.end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)
        else:
            return DeltaMessage(reasoning_content=delta_text)


# Register immediately when this module is imported.
ReasoningParserManager.register_module("qwen3", Qwen3ReasoningParser)
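The three outcomes of `Qwen3ReasoningParser.extract_reasoning` (both delimiters present, only `</think>` present, no `</think>` at all) can be checked with a standalone sketch. `split_reasoning` is a hypothetical free function that mirrors the method's string handling; it does not touch the tokenizer or the request object.

```python
from typing import Optional, Tuple

START, END = "<think>", "</think>"


def split_reasoning(
    output: str, thinking_enabled: bool = True
) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning, content) for a complete model output."""
    # Drop everything up to and including <think> if the model emitted it.
    before, found, after = output.partition(START)
    output = after if found else before

    if END not in output:
        if not thinking_enabled:
            return None, output   # plain answer, no reasoning block
        return output, None       # truncated before </think>

    reasoning, _, content = output.partition(END)
    return reasoning, content or None


old_style = split_reasoning("<think>plan</think>answer")
new_style = split_reasoning("plan</think>answer")
truncated = split_reasoning("cut off mid-thought")
disabled = split_reasoning("plain answer", thinking_enabled=False)
```

The old-style and new-style outputs yield the same split, which is the point of handling both chat-template variants in one parser.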
1385
qwen3_6_scripts/sequence.py
Normal file
File diff suppressed because it is too large
999
qwen3_6_scripts/serving_chat.py
Normal file
@@ -0,0 +1,999 @@
import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         load_chat_template,
                                         parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
    ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServing,
                                                    PromptAdapterPath,
                                                    TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine_client: EngineClient,
                 model_config: ModelConfig,
                 base_model_paths: List[BaseModelPath],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None,
                 reasoning_parser: Optional[str] = None):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled please note that while"
                " the parallel_tool_calls client option is preset for "
                "compatibility reasons, it will be ignored.")

        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            try:
                self.tool_parser = ToolParserManager.get_tool_parser(
                    tool_parser)
            except Exception as e:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                f"tool_parser:'{tool_parser}' which has not "
                                "been registered") from e

        # set up reasoning parser
        self.reasoning_parser_cls = None
        if reasoning_parser:
            try:
                from vllm.reasoning import ReasoningParserManager
                self.reasoning_parser_cls = \
                    ReasoningParserManager.get_reasoning_parser(reasoning_parser)
                logger.info("Reasoning parser '%s' enabled.", reasoning_parser)
            except Exception as e:
                raise TypeError(
                    f"Error: --reasoning-parser '{reasoning_parser}' could not "
                    "be loaded. Make sure vllm/reasoning/ is installed."
                ) from e

async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||
ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI
|
||||
ChatCompletion API.
|
||||
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt: Union[str, List[int]]
            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
            if is_mistral_tokenizer:
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
        except Exception as e:
            logger.exception("Error in applying chat template from request")
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.exception("Error in loading multi-modal data")
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            # for hf tokenizers, "auto" tools requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            if self.enable_auto_tools and self.tool_parser:
                request = self.tool_parser(tokenizer).adjust_request(
                    request=request)

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                assert isinstance(prompt, list) and isinstance(
                    prompt[0], int
                ), "Prompt has to be either a string or a list of token ids"
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            sampling_params: Union[SamplingParams, BeamSearchParams]
            default_max_tokens = self.max_model_len - len(
                prompt_inputs["prompt_token_ids"])
            if request.use_beam_search:
                sampling_params = request.to_beam_search_params(
                    default_max_tokens)
            else:
                sampling_params = request.to_sampling_params(
                    default_max_tokens)

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            is_tracing_enabled = (await
                                  self.engine_client.is_tracing_enabled())
            trace_headers = None
            if is_tracing_enabled and raw_request:
                trace_headers = extract_trace_headers(raw_request.headers)
            if (not is_tracing_enabled and raw_request
                    and contains_trace_headers(raw_request.headers)):
                log_tracing_disabled_warning()

            if isinstance(sampling_params, BeamSearchParams):
                assert isinstance(self.engine_client,
                                  (AsyncLLMEngine,
                                   MQLLMEngineClient)), \
                    "Beam search is only supported with " \
                    "AsyncLLMEngine and MQLLMEngineClient."
                result_generator = self.engine_client.beam_search(
                    engine_inputs['prompt_token_ids'],
                    request_id,
                    sampling_params,
                )
            else:
                result_generator = self.engine_client.generate(
                    engine_inputs,
                    sampling_params,
                    request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=request.priority,
                )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer,
                request_metadata)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer,
                request_metadata)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

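The tool-choice validation and the token-budget calculation in `create_chat_completion` can be tried in isolation. The following is a minimal sketch rather than the server's code: `validate_tool_choice` and `default_max_tokens` are hypothetical helpers, and `tool_choice` is simplified to a string (the real request may also carry a named-tool object). Only the rules themselves come from the method above: `"required"` is rejected, `"auto"` needs auto-tool support unless a Mistral tokenizer is used, and the default generation budget is the model's context length minus the prompt length.

```python
from typing import List, Optional


def validate_tool_choice(tool_choice: Optional[str],
                         is_mistral_tokenizer: bool,
                         auto_tools_enabled: bool) -> Optional[str]:
    """Return an error message, or None if the request may proceed.

    Hypothetical helper mirroring the checks in create_chat_completion.
    """
    if tool_choice == "required":
        return 'tool_choice = "required" is not supported!'
    if (not is_mistral_tokenizer and tool_choice == "auto"
            and not auto_tools_enabled):
        return ('"auto" tool choice requires '
                "--enable-auto-tool-choice and --tool-call-parser to be set")
    return None


def default_max_tokens(max_model_len: int,
                       prompt_token_ids: List[int]) -> int:
    """Tokens left for generation once the prompt is accounted for."""
    return max_model_len - len(prompt_token_ids)


# "auto" without auto-tool support on a non-Mistral tokenizer is rejected.
print(validate_tool_choice("auto", False, False))
# A 3-token prompt in a 10000-token context leaves 9997 tokens.
print(default_max_tokens(10000, [1, 2, 3]))
```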
    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        use_reasoning = self.reasoning_parser_cls is not None

        all_previous_token_ids: Optional[List[List[int]]]
        # previous_texts / all_previous_token_ids are needed for both tool
        # parsing and reasoning parsing (both require full-history context).
        if tool_choice_auto or use_reasoning:
            previous_texts = [""] * num_choices
            # One list per choice, so the histories are never aliased.
            all_previous_token_ids = [[] for _ in range(num_choices)]
        else:
            previous_texts, all_previous_token_ids = None, None

        # Prepare the tool parser if it's needed
        try:
            if tool_choice_auto and self.tool_parser:
                # One parser instance per choice, so that streaming state is
                # not shared between choices when request.n > 1.
                tool_parsers: List[Optional[ToolParser]] = [
                    self.tool_parser(tokenizer) for _ in range(num_choices)
                ]
            else:
                tool_parsers = [None] * num_choices
        except RuntimeError as e:
            logger.error("Error in tool parser creation: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Prepare reasoning parsers (one instance per choice for state isolation)
        reasoning_parsers: List[Optional[object]] = [None] * num_choices
        reasoning_end_arr: List[bool] = [False] * num_choices
        reasoning_token_counts: List[int] = [0] * num_choices
        if use_reasoning:
            try:
                reasoning_parsers = [
                    self.reasoning_parser_cls(
                        tokenizer,
                        chat_template_kwargs=request.chat_template_kwargs)
                    for _ in range(num_choices)
                ]
                # If thinking is disabled per-request, mark reasoning as
                # already ended so the tool-auto branch is reachable.
                for idx, rp in enumerate(reasoning_parsers):
                    if hasattr(rp, 'thinking_enabled') and not rp.thinking_enabled:
                        reasoning_end_arr[idx] = True
            except RuntimeError as e:
                logger.error("Error in reasoning parser creation: %s", e)
                data = self.create_streaming_error_response(str(e))
                yield f"data: {data}\n\n"
                yield "data: [DONE]\n\n"
                return

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)
                    if res.encoder_prompt_token_ids is not None:
                        num_prompt_tokens += len(res.encoder_prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...except).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        tool_parser = tool_parsers[i]
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if usage should be included
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            # if continuous usage stats are requested, add it
                            if request.stream_options.continuous_usage_stats:
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=0,
                                    total_tokens=num_prompt_tokens)
                                chunk.usage = usage
                            # otherwise don't
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo or request.continue_final_message:
                        last_msg_content: str = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    if (request.stream_options.
                                            continuous_usage_stats):
                                        usage = UsageInfo(
                                            prompt_tokens=num_prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=num_prompt_tokens)
                                        chunk.usage = usage
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index
                    tool_parser = tool_parsers[i]

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text
                    delta_message: Optional[DeltaMessage]

                    # Maintain text/token history when either reasoning or
                    # auto-tool parsing is active.
                    assert previous_texts is not None or not (
                        tool_choice_auto or use_reasoning)
                    if previous_texts is not None:
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids
                    else:
                        previous_text = ""
                        previous_token_ids = []
                        current_text = delta_text
                        current_token_ids = list(output.token_ids)

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle reasoning: route through reasoning parser while
                    # </think> has not yet been seen.
                    elif use_reasoning and not reasoning_end_arr[i]:
                        r_parser = reasoning_parsers[i]
                        delta_message = r_parser.extract_reasoning_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                        )
                        # Mark reasoning as ended when end token appears.
                        if r_parser.end_token_id in current_token_ids:
                            reasoning_end_arr[i] = True

                    # handle streaming deltas for tools with "auto" tool choice
                    # (only reached after reasoning block, if any, has ended)
                    elif tool_choice_auto:
                        assert tool_parser is not None
                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids,
                                request=request))

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # If the message delta is None (e.g. because it was a
                    # "control token" for tool calls, or the parser otherwise
                    # wasn't ready to send a token), get the next token
                    # without streaming a chunk.
                    # However, if this is the finish token we must NOT skip:
                    # the finish block updates reasoning_token_counts, sets
                    # finish_reason_sent, and flushes the final usage chunk.
                    if delta_message is None:
                        if output.finish_reason is None:
                            continue
                        delta_message = DeltaMessage()

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n

                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # handle usage stats if requested & if continuous
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        auto_tools_called = False
                        if tool_parser:
                            auto_tools_called = len(
                                tool_parser.prev_tool_call_arr) > 0
                            index = len(tool_parser.prev_tool_call_arr
                                        ) - 1 if auto_tools_called else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Count reasoning tokens for this choice at finish time.
                        if use_reasoning and all_previous_token_ids is not None:
                            r_parser = reasoning_parsers[i]
                            reasoning_token_counts[i] = \
                                r_parser.count_reasoning_tokens(
                                    all_previous_token_ids[i])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not auto_tools_called else "tool_calls",
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is set, send the usage
            if (request.stream_options
                    and request.stream_options.include_usage):
                # Aggregate over all choices; indexing by the loop variable
                # `i` would report only the last choice that was processed.
                completion_tokens = sum(previous_num_tokens)
                total_reasoning = sum(reasoning_token_counts) if use_reasoning else None
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                    reasoning_tokens=total_reasoning,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            num_completion_tokens = sum(previous_num_tokens)
            total_reasoning = sum(reasoning_token_counts) if use_reasoning else None
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens,
                reasoning_tokens=total_reasoning)

        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error("error in chat completion stream generator: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

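The stream generator above emits each chunk as a `data: <json>` line followed by a blank line, and ends with `data: [DONE]`. A minimal client-side sketch of consuming that framing follows. The helper names `iter_sse_chunks` and `collect_content` are hypothetical, and the sample lines are hand-written for illustration rather than captured from a server.

```python
import json
from typing import Dict, Iterable, Iterator, List


def iter_sse_chunks(lines: Iterable[str]) -> Iterator[dict]:
    """Yield decoded JSON payloads from 'data: ...' lines until [DONE]."""
    for line in lines:
        line = line.strip()
        if not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        payload = line[len("data: "):]
        if payload == "[DONE]":
            return
        yield json.loads(payload)


def collect_content(chunks: Iterable[dict]) -> Dict[int, str]:
    """Concatenate delta.content per choice index."""
    texts: Dict[int, List[str]] = {}
    for chunk in chunks:
        # The final usage chunk has an empty "choices" list and is skipped.
        for choice in chunk.get("choices", []):
            piece = choice.get("delta", {}).get("content")
            if piece:
                texts.setdefault(choice["index"], []).append(piece)
    return {idx: "".join(parts) for idx, parts in texts.items()}


# Hand-written sample stream (illustrative data only).
sample = [
    'data: {"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}]}',
    "",
    'data: {"choices": [{"index": 0, "delta": {"content": "Hel"}}]}',
    "",
    'data: {"choices": [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]}',
    "",
    'data: {"choices": [], "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}}',
    "",
    "data: [DONE]",
]
print(collect_content(iter_sse_chunks(sample)))  # {0: 'Hello'}
```

Keying the collected text by `index` matters when `n > 1`, because chunks for different choices are interleaved in one stream.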
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:

        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # In the OpenAI API the finish_reason is "tool_calls"
            # if the tool choice is auto and the model produced a tool
            # call. The same is not true for named function calls
            auto_tools_called = False

            # Extract reasoning content if parser is configured.
            # output_text is what remains after stripping <think>...</think>.
            reasoning_text: Optional[str] = None
            output_text: str = output.text
            if self.reasoning_parser_cls:
                r_parser = self.reasoning_parser_cls(
                    tokenizer,
                    chat_template_kwargs=request.chat_template_kwargs)
                reasoning_text, extracted = r_parser.extract_reasoning(
                    output.text, request)
                output_text = extracted or ""

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools
                    or not self.tool_parser) and not isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_text,
                                      content=output_text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:

                message = ChatMessage(
                    role=role,
                    reasoning_content=reasoning_text,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output_text))
                    ])

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":

                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_text,
                                      content=output_text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                try:
                    tool_parser = self.tool_parser(tokenizer)
                except RuntimeError as e:
                    logger.error("Error in tool parser creation: %s", e)
                    return self.create_error_response(str(e))

                # Parse tool calls from the post-reasoning content.
                tool_call_info = tool_parser.extract_tool_calls(
                    output_text, request=request)
                auto_tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(
                        role=role,
                        reasoning_content=reasoning_text,
                        content=tool_call_info.content,
                        tool_calls=tool_call_info.tool_calls)
                else:
                    message = ChatMessage(role=role,
                                          reasoning_content=reasoning_text,
                                          content=output_text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_text,
                                      content=output_text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if auto_tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo or request.continue_final_message:
            last_msg_content = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        total_reasoning_tokens: Optional[int] = None
        if self.reasoning_parser_cls:
            rp = self.reasoning_parser_cls(
                tokenizer,
                chat_template_kwargs=request.chat_template_kwargs)
            total_reasoning_tokens = sum(
                rp.count_reasoning_tokens(list(output.token_ids))
                for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
            reasoning_tokens=total_reasoning_tokens,
        )

        request_metadata.final_usage_info = usage

        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

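The non-streaming path above separates the `<think>...</think>` block from the visible answer before any tool parsing happens. The reasoning parser class itself is not shown in this part of the diff, so the following is only a sketch of the idea: `split_reasoning` is a hypothetical stand-in for `extract_reasoning`, and its handling of a missing end tag or missing start tag is an assumption, not the parser's documented behaviour.

```python
from typing import Optional, Tuple

THINK_START = "<think>"
THINK_END = "</think>"


def split_reasoning(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Split model output into (reasoning, content).

    Hypothetical stand-in for a reasoning parser's extract_reasoning().
    """
    if THINK_END not in text:
        # Assumption: without an end tag, everything is ordinary content.
        return None, text
    reasoning, _, content = text.partition(THINK_END)
    # Assumption: the start tag may be absent, e.g. when the chat template
    # already emitted it as part of the prompt.
    if reasoning.startswith(THINK_START):
        reasoning = reasoning[len(THINK_START):]
    return reasoning.strip(), content.strip() or None


reasoning_text, extracted = split_reasoning(
    "<think>2 + 2 is 4.</think>The answer is 4.")
# Same fallback as in chat_completion_full_generator: never pass None on.
output_text = extracted or ""
print(reasoning_text)
print(output_text)
```

The `extracted or ""` fallback mirrors the method above, which keeps `output_text` a string even when the model produced only reasoning.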
    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

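Both logprob helpers above clamp each value with `max(logprob, -9999.0)` and expose the token as a list of UTF-8 byte values. One plausible reason for the clamp, stated here as an assumption since the source does not explain it, is that `-inf` cannot be represented in strict JSON. The sketch below demonstrates both transformations with hypothetical helper names.

```python
import json
import math
from typing import List

MIN_LOGPROB = -9999.0  # same floor as in the methods above


def clamp_logprob(logprob: float) -> float:
    """Replace -inf (or anything below the floor) with a finite value."""
    return max(logprob, MIN_LOGPROB)


def token_bytes(token: str) -> List[int]:
    """UTF-8 byte values of a decoded token."""
    return list(token.encode("utf-8", errors="replace"))


entry = {
    "token": "é",
    "logprob": clamp_logprob(-math.inf),
    "bytes": token_bytes("é"),  # two bytes: 195, 169
}
# allow_nan=False makes json.dumps raise on inf/nan, so this call succeeds
# only because the value was clamped first.
print(json.dumps(entry, allow_nan=False))
```

The byte list lets a client reassemble characters that were split across several tokens, which the decoded token strings alone cannot guarantee.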
    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

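The four conditions in the docstring above can be checked as a plain predicate. This is a hypothetical free-function version for illustration: it takes the relevant values as arguments instead of reading them from the request and the server, and it wraps the result in `bool`, whereas the method above returns whichever operand decided the `and` chain.

```python
from typing import List, Optional


def should_stream_with_auto_tool_parsing(tools: Optional[List[dict]],
                                         tool_parser_configured: bool,
                                         enable_auto_tools: bool,
                                         tool_choice: Optional[str]) -> bool:
    """Hypothetical stand-alone version of the predicate above."""
    return bool(tools and tool_parser_configured and enable_auto_tools
                and tool_choice in ("auto", None))


# Placeholder tool definition; "get_weather" is an invented name.
tools = [{"type": "function", "function": {"name": "get_weather"}}]

print(should_stream_with_auto_tool_parsing(tools, True, True, "auto"))
print(should_stream_with_auto_tool_parsing(tools, True, True, None))
print(should_stream_with_auto_tool_parsing(None, True, True, "auto"))
print(should_stream_with_auto_tool_parsing(tools, True, False, "auto"))
```

Note that an omitted `tool_choice` (`None`) is treated like `"auto"`, so supplying tools is enough to route the stream through the tool parser when the server flags are set.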
    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool arguments tokens.
        This is only applicable when auto tool parsing is enabled and the
        delta is a tool call with arguments.
        """

        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            output.finish_reason is not None
            and self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )
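When this check passes, the stream generator computes the still-unsent tail of the tool arguments by removing what has already been streamed from the complete JSON. The sketch below reproduces that string arithmetic with made-up values: the tool name `get_weather` and its arguments are hypothetical, and `prev_tool_call` stands in for an entry of the parser's `prev_tool_call_arr`.

```python
import json

# What the partial-JSON parser believes the complete call is.
prev_tool_call = {"name": "get_weather", "arguments": {"city": "Paris"}}
expected_call = json.dumps(prev_tool_call.get("arguments", {}))

# What has been streamed to the client so far; the closing quote and brace
# have not been sent yet.
actual_call = '{"city": "Paris'

# Remove the first occurrence of the streamed text; the rest is unsent.
remaining_call = expected_call.replace(actual_call, "", 1)
print(remaining_call)  # "}
```

This only works when the streamed text is a literal substring of the `json.dumps` output, so it depends on the parser having streamed the arguments in the same formatting.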