Compare commits
10 Commits
2d1ef50992
...
72aa7e690a
| Author | SHA1 | Date | |
|---|---|---|---|
| 72aa7e690a | |||
| b5806731e0 | |||
| 47a4d9e72a | |||
| 3b8a567e9e | |||
| 50e3a05fb0 | |||
| 629f878c28 | |||
| 365da18436 | |||
| 4ab36b51d5 | |||
| d972854fb7 | |||
| c2de1c83b0 |
132
README.md
132
README.md
@@ -1,118 +1,46 @@
|
|||||||
# 天数智芯 天垓100 文本生成引擎(基于 vLLM 优化)
|
# 天数智芯 天垓100 文本生成引擎(基于 vLLM 优化适配Qwen3.6-35B-A3B)
|
||||||
|
|
||||||
本项目是为**天数智芯-天垓100**加速卡深度优化的高性能文本生成推理引擎,基于开源 **vLLM** 框架进行架构级适配与增强,率先实现对 **Qwen3 系列**等最新大模型的高效支持。通过引入 **Prefix Caching**、PagedAttention 等先进优化技术,显著提升吞吐与响应速度,同时提供标准 **OpenAI 兼容 API 接口**,便于无缝集成现有应用生态。
|
|
||||||
|
|
||||||
## 支持模型
|
|
||||||
|
|
||||||
- **Qwen3**
|
|
||||||
- **Llama3**
|
|
||||||
- **DeepSeek-R1-Distill**
|
|
||||||
- 其他兼容 vLLM 的 HuggingFace 模型(持续扩展中)
|
|
||||||
|
|
||||||
> 模型下载地址:[https://modelscope.cn/models/Qwen](https://modelscope.cn/models/Qwen)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### 1. 模型下载
|
|
||||||
|
|
||||||
从 ModelScope 下载所需模型(以 Qwen2.5-7B-Instruct 为例):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
modelscope download --model qwen/Qwen2.5-7B-Instruct README.md --local_dir /mnt/models/Qwen2.5-7B-Instruct
|
|
||||||
```
|
|
||||||
|
|
||||||
> ⚠️ 请确保模型路径在后续 Docker 启动时正确挂载。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2. 拉取并构建 Docker 镜像
|
|
||||||
|
|
||||||
我们提供已预装天垓100驱动与vLLM优化版本的Docker镜像:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
# 本地构建
|
# 本地构建
|
||||||
docker build -t enginex-iluvatar-vllm:bi100 -f Dockerfile .
|
docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile .
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 3. 启动服务容器
|
启动容器镜像
|
||||||
|
|
||||||
```bash
|
下载Qwen3.6-35B-A3B模型,并且需要将模型的config.json文件中architectures字段改成
|
||||||
docker run -it --rm -p 8000:80 \
|
```json
|
||||||
--name vllm-iluvatar \
|
"architectures": [
|
||||||
-v /mnt/models/Qwen2.5-7B-Instruct:/model:ro \
|
"Qwen3_5MoeForCausalLM"
|
||||||
--privileged \
|
]
|
||||||
-e TENSOR_PARALLEL_SIZE=1 \
|
|
||||||
-e PREFIX_CACHING=true \
|
|
||||||
-e MAX_MODEL_LEN=10000 \
|
|
||||||
enginex-iluvatar-vllm:bi100
|
|
||||||
```
|
```
|
||||||
|
|
||||||
> ✅ 参数说明:
|
|
||||||
> - `PREFIX_CACHING=true`: 启用 Prefix Caching 优化,显著提升多请求共享前缀的推理效率
|
|
||||||
> - `MAX_MODEL_LEN=10000`: 支持长上下文推理
|
|
||||||
> - `--privileged`: 确保天垓100设备可见
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. 测试服务(使用 OpenAI 兼容接口)
|
|
||||||
|
|
||||||
服务启动后,可通过标准 OpenAI SDK 或 `curl` 进行测试。
|
|
||||||
|
|
||||||
### 示例:文本生成请求
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl http://localhost:8000/v1/chat/completions \
|
docker run -dit --network=host --ipc=host \
|
||||||
|
-v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \
|
||||||
|
-v /mnt/disk1/models/Qwen3.6-35B-A3B:/model:ro --entrypoint=python3 \
|
||||||
|
-e CUDA_VISIBLE_DEVICES=4,5,6,7 -e VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 \
|
||||||
|
enginex-iluvatar-vllm:bi100-qwen3.6 \
|
||||||
|
-m vllm.entrypoints.openai.api_server \
|
||||||
|
--model /model --port 1111 --served-model-name llm \
|
||||||
|
--max-model-len 100000 --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
|
||||||
|
--max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
||||||
|
--max-num-batched-tokens 4096 --enable-chunked-prefill \
|
||||||
|
--max-seq-len-to-capture 32768 --enable-auto-tool-choice \
|
||||||
|
--tool-call-parser qwen3_coder --reasoning-parser qwen3
|
||||||
|
```
|
||||||
|
|
||||||
|
请求
|
||||||
|
```bash
|
||||||
|
curl http://localhost:1111/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "qwen3-8b",
|
"model": "llm",
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "请用中文介绍一下上海的特点。"}
|
{"role": "user", "content": "Can you tell me the story of Snow White?"}
|
||||||
],
|
],
|
||||||
"temperature": 0.7,
|
"max_tokens": 200,
|
||||||
"max_tokens": 512
|
"temperature": 0.7
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### 使用 OpenAI Python SDK(需安装 `openai>=1.0`)
|
|
||||||
|
|
||||||
```python
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model="qwen3-8b",
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "请简要介绍杭州的特色文化。"}
|
|
||||||
],
|
|
||||||
max_tokens=512,
|
|
||||||
temperature=0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
print(response.choices[0].message.content)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 测试结果对比(A100 vs 天垓100)
|
|
||||||
|
|
||||||
### 测试数据集
|
|
||||||
|
|
||||||
[chat_dataset_v0.json](chat_dataset_v0.json)
|
|
||||||
|
|
||||||
### 测试结果
|
|
||||||
|
|
||||||
在相同模型和输入条件下,测试平均输出速度(单位:字每秒),结果如下:
|
|
||||||
|
|
||||||
| 模型 | 天垓100 输出速度 | Nvidia A100 输出速度 |
|
|
||||||
|--------|--------------------------|-------------------------------|
|
|
||||||
| Qwen2.5-7B-Instruct | 36.8 | 112.4 |
|
|
||||||
| Qwen2.5-1.5B-Instruct-AWQ | 72.4 | 100.8 |
|
|
||||||
| Qwen/Qwen1.5-32B-Chat | 12.4 | 55.7 |
|
|
||||||
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
# 天数智芯 天垓100 文本生成引擎(基于 vLLM 优化适配Qwen3.6-27B)
|
|
||||||
|
|
||||||
```
|
|
||||||
# 本地构建
|
|
||||||
docker build -t enginex-iluvatar-vllm:bi100-qwen3.6 -f Dockerfile .
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
启动容器镜像
|
|
||||||
|
|
||||||
下载Qwen3.6-27B模型,并且需要将模型的config.json文件中architectures字段改成
|
|
||||||
```json
|
|
||||||
"architectures": [
|
|
||||||
"Qwen3_5ForCausalLM"
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run -dit --network=host --ipc=host \
|
|
||||||
-v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev --privileged \
|
|
||||||
--name vllm-iluvatar \
|
|
||||||
-v /mnt/models/Qwen3.6-27B:/model:ro --entrypoint=python3 \
|
|
||||||
enginex-iluvatar-vllm:bi100 \
|
|
||||||
-m vllm.entrypoints.openai.api_server \
|
|
||||||
--model /model --port 1111 --served-model-name llm \
|
|
||||||
--max-model-len 10000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95
|
|
||||||
```
|
|
||||||
|
|
||||||
请求
|
|
||||||
```bash
|
|
||||||
curl http://localhost:1111/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model": "llm",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Can you tell me the story of Snow White?"}
|
|
||||||
],
|
|
||||||
"max_tokens": 200,
|
|
||||||
"temperature": 0.7
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
33
computility-run.yaml
Normal file
33
computility-run.yaml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
gpu_num: 4
|
||||||
|
command:
|
||||||
|
- python3
|
||||||
|
- -m
|
||||||
|
- vllm.entrypoints.openai.api_server
|
||||||
|
- --model
|
||||||
|
- /model
|
||||||
|
- --served-model-name
|
||||||
|
- llm
|
||||||
|
- --max-model-len
|
||||||
|
- '100000'
|
||||||
|
- --gpu-memory-utilization
|
||||||
|
- '0.95'
|
||||||
|
- --trust-remote-code
|
||||||
|
- -tp
|
||||||
|
- '4'
|
||||||
|
- --max-num-seqs
|
||||||
|
- '1'
|
||||||
|
- --disable-log-requests
|
||||||
|
- --disable-frontend-multiprocessing
|
||||||
|
- --max-num-batched-tokens
|
||||||
|
- '4096'
|
||||||
|
- --enable-chunked-prefill
|
||||||
|
- --max-seq-len-to-capture
|
||||||
|
- '32768'
|
||||||
|
- --enable-auto-tool-choice
|
||||||
|
- --tool-call-parser
|
||||||
|
- qwen3_coder
|
||||||
|
- --reasoning-parser
|
||||||
|
- qwen3
|
||||||
|
env:
|
||||||
|
- name: VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||||
|
value: 3600
|
||||||
595
qwen3_6_scripts/api_server.py
Normal file
595
qwen3_6_scripts/api_server.py
Normal file
@@ -0,0 +1,595 @@
|
|||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import regex as re
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import tempfile
|
||||||
|
from argparse import Namespace
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import AsyncIterator, Set
|
||||||
|
|
||||||
|
import uvloop
|
||||||
|
from fastapi import APIRouter, FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
from starlette.datastructures import State
|
||||||
|
from starlette.routing import Mount
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
|
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.launcher import serve_http
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
|
validate_parsed_serve_args)
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
DetokenizeRequest,
|
||||||
|
DetokenizeResponse,
|
||||||
|
EmbeddingRequest,
|
||||||
|
EmbeddingResponse, ErrorResponse,
|
||||||
|
LoadLoraAdapterRequest,
|
||||||
|
TokenizeRequest,
|
||||||
|
TokenizeResponse,
|
||||||
|
UnloadLoraAdapterRequest)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
|
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||||
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
|
OpenAIServingTokenization)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
|
from vllm.reasoning import ReasoningParserManager
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||||
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
|
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||||
|
|
||||||
|
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||||
|
|
||||||
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
|
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||||
|
|
||||||
|
_running_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
try:
|
||||||
|
if app.state.log_stats:
|
||||||
|
engine_client: EngineClient = app.state.engine_client
|
||||||
|
|
||||||
|
async def _force_log():
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(10.)
|
||||||
|
await engine_client.do_log_stats()
|
||||||
|
|
||||||
|
task = asyncio.create_task(_force_log())
|
||||||
|
_running_tasks.add(task)
|
||||||
|
task.add_done_callback(_running_tasks.remove)
|
||||||
|
else:
|
||||||
|
task = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if task is not None:
|
||||||
|
task.cancel()
|
||||||
|
finally:
|
||||||
|
# Ensure app state including engine ref is gc'd
|
||||||
|
del app.state
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def build_async_engine_client(
|
||||||
|
args: Namespace) -> AsyncIterator[EngineClient]:
|
||||||
|
|
||||||
|
# Context manager to handle engine_client lifecycle
|
||||||
|
# Ensures everything is shutdown and cleaned up on error/exit
|
||||||
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
async with build_async_engine_client_from_engine_args(
|
||||||
|
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def build_async_engine_client_from_engine_args(
|
||||||
|
engine_args: AsyncEngineArgs,
|
||||||
|
disable_frontend_multiprocessing: bool = False,
|
||||||
|
) -> AsyncIterator[EngineClient]:
|
||||||
|
"""
|
||||||
|
Create EngineClient, either:
|
||||||
|
- in-process using the AsyncLLMEngine Directly
|
||||||
|
- multiprocess using AsyncLLMEngine RPC
|
||||||
|
|
||||||
|
Returns the Client or None if the creation failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Fall back
|
||||||
|
# TODO: fill out feature matrix.
|
||||||
|
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||||
|
or disable_frontend_multiprocessing):
|
||||||
|
engine_config = engine_args.create_engine_config()
|
||||||
|
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||||
|
"uses_ray", False)
|
||||||
|
|
||||||
|
build_engine = partial(AsyncLLMEngine.from_engine_args,
|
||||||
|
engine_args=engine_args,
|
||||||
|
engine_config=engine_config,
|
||||||
|
usage_context=UsageContext.OPENAI_API_SERVER)
|
||||||
|
if uses_ray:
|
||||||
|
# Must run in main thread with ray for its signal handlers to work
|
||||||
|
engine_client = build_engine()
|
||||||
|
else:
|
||||||
|
engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None, build_engine)
|
||||||
|
|
||||||
|
yield engine_client
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||||
|
else:
|
||||||
|
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
|
||||||
|
# Make TemporaryDirectory for prometheus multiprocessing
|
||||||
|
# Note: global TemporaryDirectory will be automatically
|
||||||
|
# cleaned up upon exit.
|
||||||
|
global prometheus_multiproc_dir
|
||||||
|
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
|
||||||
|
os.environ[
|
||||||
|
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
|
||||||
|
"This directory must be wiped between vLLM runs or "
|
||||||
|
"you will find inaccurate metrics. Unset the variable "
|
||||||
|
"and vLLM will properly handle cleanup.")
|
||||||
|
|
||||||
|
# Select random path for IPC.
|
||||||
|
ipc_path = get_open_zmq_ipc_path()
|
||||||
|
logger.info("Multiprocessing frontend to use %s for IPC Path.",
|
||||||
|
ipc_path)
|
||||||
|
|
||||||
|
# Start RPCServer in separate process (holds the LLMEngine).
|
||||||
|
# the current process might have CUDA context,
|
||||||
|
# so we need to spawn a new process
|
||||||
|
context = multiprocessing.get_context("spawn")
|
||||||
|
|
||||||
|
engine_process = context.Process(target=run_mp_engine,
|
||||||
|
args=(engine_args,
|
||||||
|
UsageContext.OPENAI_API_SERVER,
|
||||||
|
ipc_path))
|
||||||
|
engine_process.start()
|
||||||
|
logger.info("Started engine process with PID %d", engine_process.pid)
|
||||||
|
|
||||||
|
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||||
|
# NOTE: Actually, this is not true yet. We still need to support
|
||||||
|
# embedding models via RPC (see TODO above)
|
||||||
|
engine_config = engine_args.create_engine_config()
|
||||||
|
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await mp_engine_client.setup()
|
||||||
|
break
|
||||||
|
except TimeoutError:
|
||||||
|
if not engine_process.is_alive():
|
||||||
|
raise RuntimeError(
|
||||||
|
"Engine process failed to start") from None
|
||||||
|
|
||||||
|
yield mp_engine_client # type: ignore[misc]
|
||||||
|
finally:
|
||||||
|
# Ensure rpc server process was terminated
|
||||||
|
engine_process.terminate()
|
||||||
|
|
||||||
|
# Close all open connections to the backend
|
||||||
|
mp_engine_client.close()
|
||||||
|
|
||||||
|
# Wait for engine process to join
|
||||||
|
engine_process.join(4)
|
||||||
|
if engine_process.exitcode is None:
|
||||||
|
# Kill if taking longer than 5 seconds to stop
|
||||||
|
engine_process.kill()
|
||||||
|
|
||||||
|
# Lazy import for prometheus multiprocessing.
|
||||||
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
|
# before prometheus_client is imported.
|
||||||
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
|
from prometheus_client import multiprocess
|
||||||
|
multiprocess.mark_process_dead(engine_process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def mount_metrics(app: FastAPI):
|
||||||
|
# Lazy import for prometheus multiprocessing.
|
||||||
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
|
# before prometheus_client is imported.
|
||||||
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
|
from prometheus_client import (CollectorRegistry, make_asgi_app,
|
||||||
|
multiprocess)
|
||||||
|
|
||||||
|
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||||
|
if prometheus_multiproc_dir_path is not None:
|
||||||
|
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||||
|
prometheus_multiproc_dir_path)
|
||||||
|
registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(registry)
|
||||||
|
|
||||||
|
# Add prometheus asgi middleware to route /metrics requests
|
||||||
|
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||||
|
else:
|
||||||
|
# Add prometheus asgi middleware to route /metrics requests
|
||||||
|
metrics_route = Mount("/metrics", make_asgi_app())
|
||||||
|
|
||||||
|
# Workaround for 307 Redirect for /metrics
|
||||||
|
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||||
|
app.routes.append(metrics_route)
|
||||||
|
|
||||||
|
|
||||||
|
def chat(request: Request) -> OpenAIServingChat:
|
||||||
|
return request.app.state.openai_serving_chat
|
||||||
|
|
||||||
|
|
||||||
|
def completion(request: Request) -> OpenAIServingCompletion:
|
||||||
|
return request.app.state.openai_serving_completion
|
||||||
|
|
||||||
|
|
||||||
|
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||||
|
return request.app.state.openai_serving_tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def embedding(request: Request) -> OpenAIServingEmbedding:
|
||||||
|
return request.app.state.openai_serving_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def engine_client(request: Request) -> EngineClient:
|
||||||
|
return request.app.state.engine_client
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health(raw_request: Request) -> Response:
|
||||||
|
"""Health check."""
|
||||||
|
await engine_client(raw_request).check_health()
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tokenize")
|
||||||
|
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||||
|
generator = await tokenization(raw_request).create_tokenize(request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
elif isinstance(generator, TokenizeResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
assert_never(generator)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detokenize")
|
||||||
|
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||||
|
generator = await tokenization(raw_request).create_detokenize(request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
elif isinstance(generator, DetokenizeResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
assert_never(generator)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/models")
|
||||||
|
async def show_available_models(raw_request: Request):
|
||||||
|
models = await completion(raw_request).show_available_models()
|
||||||
|
return JSONResponse(content=models.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/version")
|
||||||
|
async def show_version():
|
||||||
|
ver = {"version": VLLM_VERSION}
|
||||||
|
return JSONResponse(content=ver)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
|
||||||
|
async def create_chat_completion(request: ChatCompletionRequest,
|
||||||
|
raw_request: Request):
|
||||||
|
|
||||||
|
generator = await chat(raw_request).create_chat_completion(
|
||||||
|
request, raw_request)
|
||||||
|
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
|
||||||
|
elif isinstance(generator, ChatCompletionResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/completions")
|
||||||
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
|
generator = await completion(raw_request).create_completion(
|
||||||
|
request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
elif isinstance(generator, CompletionResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/embeddings")
|
||||||
|
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||||
|
generator = await embedding(raw_request).create_embedding(
|
||||||
|
request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
elif isinstance(generator, EmbeddingResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
assert_never(generator)
|
||||||
|
|
||||||
|
|
||||||
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
|
logger.warning(
|
||||||
|
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||||
|
"used for local development!")
|
||||||
|
|
||||||
|
@router.post("/start_profile")
|
||||||
|
async def start_profile(raw_request: Request):
|
||||||
|
logger.info("Starting profiler...")
|
||||||
|
await engine_client(raw_request).start_profile()
|
||||||
|
logger.info("Profiler started.")
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
@router.post("/stop_profile")
|
||||||
|
async def stop_profile(raw_request: Request):
|
||||||
|
logger.info("Stopping profiler...")
|
||||||
|
await engine_client(raw_request).stop_profile()
|
||||||
|
logger.info("Profiler stopped.")
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||||
|
logger.warning(
|
||||||
|
"Lora dynamic loading & unloading is enabled in the API server. "
|
||||||
|
"This should ONLY be used for local development!")
|
||||||
|
|
||||||
|
@router.post("/v1/load_lora_adapter")
|
||||||
|
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||||
|
raw_request: Request):
|
||||||
|
response = await chat(raw_request).load_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(content=response.model_dump(),
|
||||||
|
status_code=response.code)
|
||||||
|
|
||||||
|
response = await completion(raw_request).load_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(content=response.model_dump(),
|
||||||
|
status_code=response.code)
|
||||||
|
|
||||||
|
return Response(status_code=200, content=response)
|
||||||
|
|
||||||
|
@router.post("/v1/unload_lora_adapter")
|
||||||
|
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||||
|
raw_request: Request):
|
||||||
|
response = await chat(raw_request).unload_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(content=response.model_dump(),
|
||||||
|
status_code=response.code)
|
||||||
|
|
||||||
|
response = await completion(raw_request).unload_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(content=response.model_dump(),
|
||||||
|
status_code=response.code)
|
||||||
|
|
||||||
|
return Response(status_code=200, content=response)
|
||||||
|
|
||||||
|
|
||||||
|
def build_app(args: Namespace) -> FastAPI:
|
||||||
|
if args.disable_fastapi_docs:
|
||||||
|
app = FastAPI(openapi_url=None,
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan)
|
||||||
|
else:
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(router)
|
||||||
|
app.root_path = args.root_path
|
||||||
|
|
||||||
|
mount_metrics(app)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=args.allowed_origins,
|
||||||
|
allow_credentials=args.allow_credentials,
|
||||||
|
allow_methods=args.allowed_methods,
|
||||||
|
allow_headers=args.allowed_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(_, exc):
|
||||||
|
chat = app.state.openai_serving_chat
|
||||||
|
err = chat.create_error_response(message=str(exc))
|
||||||
|
return JSONResponse(err.model_dump(),
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
if token := envs.VLLM_API_KEY or args.api_key:
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def authentication(request: Request, call_next):
|
||||||
|
root_path = "" if args.root_path is None else args.root_path
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
if not request.url.path.startswith(f"{root_path}/v1"):
|
||||||
|
return await call_next(request)
|
||||||
|
if request.headers.get("Authorization") != "Bearer " + token:
|
||||||
|
return JSONResponse(content={"error": "Unauthorized"},
|
||||||
|
status_code=401)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
for middleware in args.middleware:
|
||||||
|
module_path, object_name = middleware.rsplit(".", 1)
|
||||||
|
imported = getattr(importlib.import_module(module_path), object_name)
|
||||||
|
if inspect.isclass(imported):
|
||||||
|
app.add_middleware(imported)
|
||||||
|
elif inspect.iscoroutinefunction(imported):
|
||||||
|
app.middleware("http")(imported)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid middleware {middleware}. "
|
||||||
|
f"Must be a function or a class.")
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def init_app_state(
|
||||||
|
engine_client: EngineClient,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
state: State,
|
||||||
|
args: Namespace,
|
||||||
|
) -> None:
|
||||||
|
if args.served_model_name is not None:
|
||||||
|
served_model_names = args.served_model_name
|
||||||
|
else:
|
||||||
|
served_model_names = [args.model]
|
||||||
|
|
||||||
|
if args.disable_log_requests:
|
||||||
|
request_logger = None
|
||||||
|
else:
|
||||||
|
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||||
|
|
||||||
|
base_model_paths = [
|
||||||
|
BaseModelPath(name=name, model_path=args.model)
|
||||||
|
for name in served_model_names
|
||||||
|
]
|
||||||
|
|
||||||
|
state.engine_client = engine_client
|
||||||
|
state.log_stats = not args.disable_log_stats
|
||||||
|
|
||||||
|
state.openai_serving_chat = OpenAIServingChat(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
base_model_paths,
|
||||||
|
args.response_role,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
prompt_adapters=args.prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
chat_template=args.chat_template,
|
||||||
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
enable_auto_tools=args.enable_auto_tool_choice,
|
||||||
|
tool_parser=args.tool_call_parser,
|
||||||
|
reasoning_parser=getattr(args, 'reasoning_parser', None))
|
||||||
|
state.openai_serving_completion = OpenAIServingCompletion(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
base_model_paths,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
prompt_adapters=args.prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
)
|
||||||
|
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
base_model_paths,
|
||||||
|
request_logger=request_logger,
|
||||||
|
)
|
||||||
|
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
base_model_paths,
|
||||||
|
lora_modules=args.lora_modules,
|
||||||
|
request_logger=request_logger,
|
||||||
|
chat_template=args.chat_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||||
|
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||||
|
logger.info("args: %s", args)
|
||||||
|
|
||||||
|
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||||
|
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||||
|
|
||||||
|
valide_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||||
|
if args.enable_auto_tool_choice \
|
||||||
|
and args.tool_call_parser not in valide_tool_parses:
|
||||||
|
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||||
|
f"(chose from {{ {','.join(valide_tool_parses)} }})")
|
||||||
|
|
||||||
|
reasoning_parser = getattr(args, 'reasoning_parser', None)
|
||||||
|
if reasoning_parser:
|
||||||
|
valid_reasoning = ReasoningParserManager.list_registered()
|
||||||
|
if reasoning_parser not in valid_reasoning:
|
||||||
|
raise KeyError(
|
||||||
|
f"invalid reasoning parser: {reasoning_parser} "
|
||||||
|
f"(chose from {{ {','.join(valid_reasoning)} }})")
|
||||||
|
|
||||||
|
# workaround to make sure that we bind the port before the engine is set up.
|
||||||
|
# This avoids race conditions with ray.
|
||||||
|
# see https://github.com/vllm-project/vllm/issues/8204
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.bind(("", args.port))
|
||||||
|
|
||||||
|
def signal_handler(*_) -> None:
|
||||||
|
# Interrupt server on sigterm while initializing
|
||||||
|
raise KeyboardInterrupt("terminated")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
async with build_async_engine_client(args) as engine_client:
|
||||||
|
app = build_app(args)
|
||||||
|
|
||||||
|
model_config = await engine_client.get_model_config()
|
||||||
|
init_app_state(engine_client, model_config, app.state, args)
|
||||||
|
|
||||||
|
shutdown_task = await serve_http(
|
||||||
|
app,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
log_level=args.uvicorn_log_level,
|
||||||
|
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||||
|
ssl_keyfile=args.ssl_keyfile,
|
||||||
|
ssl_certfile=args.ssl_certfile,
|
||||||
|
ssl_ca_certs=args.ssl_ca_certs,
|
||||||
|
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||||
|
fd=sock.fileno(),
|
||||||
|
**uvicorn_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NB: Await server shutdown only after the backend context is exited
|
||||||
|
await shutdown_task
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# NOTE(simon):
|
||||||
|
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||||
|
parser = make_arg_parser(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
validate_parsed_serve_args(args)
|
||||||
|
|
||||||
|
uvloop.run(run_server(args))
|
||||||
601
qwen3_6_scripts/chat_utils.py
Normal file
601
qwen3_6_scripts/chat_utils.py
Normal file
@@ -0,0 +1,601 @@
|
|||||||
|
import asyncio
|
||||||
|
import codecs
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||||
|
Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionContentPartImageParam)
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||||
|
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||||
|
ChatCompletionContentPartTextParam)
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||||
|
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionToolMessageParam)
|
||||||
|
# yapf: enable
|
||||||
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
from typing_extensions import Required, TypeAlias, TypedDict
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||||
|
async_get_and_parse_image,
|
||||||
|
get_and_parse_audio, get_and_parse_image)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioURL(TypedDict, total=False):
|
||||||
|
url: Required[str]
|
||||||
|
"""
|
||||||
|
Either a URL of the audio or a data URL with base64 encoded audio data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
||||||
|
audio_url: Required[AudioURL]
|
||||||
|
|
||||||
|
type: Required[Literal["audio_url"]]
|
||||||
|
"""The type of the content part."""
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionContentPartParam(TypedDict, total=False):
|
||||||
|
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
|
||||||
|
|
||||||
|
type: Required[str]
|
||||||
|
"""The type of the content part."""
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||||
|
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||||
|
ChatCompletionContentPartRefusalParam,
|
||||||
|
CustomChatCompletionContentPartParam]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||||
|
"""Enables custom roles in the Chat Completion API."""
|
||||||
|
role: Required[str]
|
||||||
|
"""The role of the message's author."""
|
||||||
|
|
||||||
|
content: Union[str, List[ChatCompletionContentPartParam]]
|
||||||
|
"""The contents of the message."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""An optional name for the participant.
|
||||||
|
|
||||||
|
Provides the model information to differentiate between participants of the
|
||||||
|
same role.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_call_id: Optional[str]
|
||||||
|
"""Tool call that this message is responding to."""
|
||||||
|
|
||||||
|
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||||
|
"""The tool calls generated by the model, such as function calls."""
|
||||||
|
|
||||||
|
reasoning_content: Optional[str]
|
||||||
|
"""Reasoning / thinking content for assistant messages (vLLM extension).
|
||||||
|
When present in a previous assistant turn, it is rendered as
|
||||||
|
<think>...</think> before the main content so the model sees its own
|
||||||
|
chain-of-thought in subsequent turns."""
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
|
||||||
|
CustomChatCompletionMessageParam]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Make fields ReadOnly once mypy supports it
|
||||||
|
class ConversationMessage(TypedDict, total=False):
|
||||||
|
role: Required[str]
|
||||||
|
"""The role of the message's author."""
|
||||||
|
|
||||||
|
content: Optional[str]
|
||||||
|
"""The contents of the message"""
|
||||||
|
|
||||||
|
tool_call_id: Optional[str]
|
||||||
|
"""Tool call that this message is responding to."""
|
||||||
|
|
||||||
|
name: Optional[str]
|
||||||
|
"""The name of the function to call"""
|
||||||
|
|
||||||
|
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||||
|
"""The tool calls generated by the model, such as function calls."""
|
||||||
|
|
||||||
|
reasoning_content: Optional[str]
|
||||||
|
"""Reasoning / thinking content for assistant messages.
|
||||||
|
Passed directly to the chat template (Qwen3 reads message.reasoning_content
|
||||||
|
natively) instead of being manually wrapped in <think>...</think>."""
|
||||||
|
|
||||||
|
|
||||||
|
ModalityStr = Literal["image", "audio", "video"]
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||||
|
"""
|
||||||
|
Tracks multi-modal items in a given request and ensures that the number
|
||||||
|
of multi-modal items in a given request does not exceed the configured
|
||||||
|
maximum per prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._model_config = model_config
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||||
|
if model_config.multimodal_config else {})
|
||||||
|
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||||
|
|
||||||
|
self._items: List[_T] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||||
|
return tokenizer.decode(token_index)
|
||||||
|
|
||||||
|
def _placeholder_str(self, modality: ModalityStr,
|
||||||
|
current_count: int) -> Optional[str]:
|
||||||
|
# TODO: Let user specify how to insert image tokens into prompt
|
||||||
|
# (similar to chat template)
|
||||||
|
hf_config = self._model_config.hf_config
|
||||||
|
model_type = hf_config.model_type
|
||||||
|
|
||||||
|
if modality == "image":
|
||||||
|
if model_type == "phi3_v":
|
||||||
|
# Workaround since this token is not defined in the tokenizer
|
||||||
|
return f"<|image_{current_count}|>"
|
||||||
|
if model_type == "minicpmv":
|
||||||
|
return "(<image>./</image>)"
|
||||||
|
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
|
||||||
|
"pixtral"):
|
||||||
|
# These models do not use image tokens in the prompt
|
||||||
|
return None
|
||||||
|
if model_type == "qwen":
|
||||||
|
return f"Picture {current_count}: <img></img>"
|
||||||
|
if model_type.startswith("llava"):
|
||||||
|
return self._cached_token_str(self._tokenizer,
|
||||||
|
hf_config.image_token_index)
|
||||||
|
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
|
||||||
|
return "<image>"
|
||||||
|
if model_type == "mllama":
|
||||||
|
return "<|image|>"
|
||||||
|
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||||
|
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||||
|
if model_type == "molmo":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
elif modality == "audio":
|
||||||
|
if model_type == "ultravox":
|
||||||
|
return "<|reserved_special_token_0|>"
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
elif modality == "video":
|
||||||
|
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||||
|
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unknown modality: {modality}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||||
|
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||||
|
|
||||||
|
# Merge all the multi-modal items
|
||||||
|
for single_mm_data in items:
|
||||||
|
for mm_key, mm_item in single_mm_data.items():
|
||||||
|
if isinstance(mm_item, list):
|
||||||
|
mm_lists[mm_key].extend(mm_item)
|
||||||
|
else:
|
||||||
|
mm_lists[mm_key].append(mm_item)
|
||||||
|
|
||||||
|
# Unpack any single item lists for models that don't expect multiple.
|
||||||
|
return {
|
||||||
|
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
||||||
|
for mm_key, mm_list in mm_lists.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Add a multi-modal item to the current prompt and returns the
|
||||||
|
placeholder string to use, if any.
|
||||||
|
"""
|
||||||
|
allowed_count = self._allowed_items.get(modality, 1)
|
||||||
|
current_count = self._consumed_items.get(modality, 0) + 1
|
||||||
|
if current_count > allowed_count:
|
||||||
|
raise ValueError(
|
||||||
|
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||||
|
"one request.")
|
||||||
|
|
||||||
|
self._consumed_items[modality] = current_count
|
||||||
|
self._items.append(item)
|
||||||
|
|
||||||
|
return self._placeholder_str(modality, current_count)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
||||||
|
|
||||||
|
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
|
return self._combine(self._items) if self._items else None
|
||||||
|
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
return MultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalItemTracker(
|
||||||
|
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
||||||
|
|
||||||
|
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
|
if self._items:
|
||||||
|
items = await asyncio.gather(*self._items)
|
||||||
|
return self._combine(items)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
return AsyncMultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalContentParser(ABC):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# multimodal placeholder_string : count
|
||||||
|
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
|
||||||
|
|
||||||
|
def _add_placeholder(self, placeholder: Optional[str]):
|
||||||
|
if placeholder:
|
||||||
|
self._placeholder_counts[placeholder] += 1
|
||||||
|
|
||||||
|
def mm_placeholder_counts(self) -> Dict[str, int]:
|
||||||
|
return dict(self._placeholder_counts)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
|
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._tracker = tracker
|
||||||
|
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
image = get_and_parse_image(image_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("image", image)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
audio = get_and_parse_audio(audio_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("audio", audio)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||||
|
|
||||||
|
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._tracker = tracker
|
||||||
|
|
||||||
|
def parse_image(self, image_url: str) -> None:
|
||||||
|
image_coro = async_get_and_parse_image(image_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("image", image_coro)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
|
audio_coro = async_get_and_parse_audio(audio_url)
|
||||||
|
|
||||||
|
placeholder = self._tracker.add("audio", audio_coro)
|
||||||
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||||
|
"""Raises if the provided chat template appears invalid."""
|
||||||
|
if chat_template is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
elif isinstance(chat_template, Path) and not chat_template.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"the supplied chat template path doesn't exist")
|
||||||
|
|
||||||
|
elif isinstance(chat_template, str):
|
||||||
|
JINJA_CHARS = "{}\n"
|
||||||
|
if not any(c in chat_template
|
||||||
|
for c in JINJA_CHARS) and not Path(chat_template).exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"The supplied chat template string ({chat_template}) "
|
||||||
|
f"appears path-like, but doesn't exist!")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"{type(chat_template)} is not a valid chat template type")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chat_template(
|
||||||
|
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
|
||||||
|
if chat_template is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with open(chat_template, "r") as f:
|
||||||
|
resolved_chat_template = f.read()
|
||||||
|
except OSError as e:
|
||||||
|
if isinstance(chat_template, Path):
|
||||||
|
raise
|
||||||
|
|
||||||
|
JINJA_CHARS = "{}\n"
|
||||||
|
if not any(c in chat_template for c in JINJA_CHARS):
|
||||||
|
msg = (f"The supplied chat template ({chat_template}) "
|
||||||
|
f"looks like a file path, but it failed to be "
|
||||||
|
f"opened. Reason: {e}")
|
||||||
|
raise ValueError(msg) from e
|
||||||
|
|
||||||
|
# If opening a file fails, set chat template to be args to
|
||||||
|
# ensure we decode so our escape are interpreted correctly
|
||||||
|
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||||
|
|
||||||
|
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||||
|
return resolved_chat_template
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||||
|
# (similar to chat template)
|
||||||
|
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||||
|
text_prompt: str) -> str:
|
||||||
|
"""Combine multimodal prompts for a multimodal language model."""
|
||||||
|
|
||||||
|
# Look through the text prompt to check for missing placeholders
|
||||||
|
missing_placeholders: List[str] = []
|
||||||
|
for placeholder in placeholder_counts:
|
||||||
|
|
||||||
|
# For any existing placeholder in the text prompt, we leave it as is
|
||||||
|
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||||
|
|
||||||
|
if placeholder_counts[placeholder] < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||||
|
"actual multimodal data items.")
|
||||||
|
|
||||||
|
missing_placeholders.extend([placeholder] *
|
||||||
|
placeholder_counts[placeholder])
|
||||||
|
|
||||||
|
# NOTE: For now we always add missing placeholders at the front of
|
||||||
|
# the prompt. This may change to be customizable in the future.
|
||||||
|
return "\n".join(missing_placeholders + [text_prompt])
|
||||||
|
|
||||||
|
|
||||||
|
# No need to validate using Pydantic again
|
||||||
|
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||||
|
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||||
|
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||||
|
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||||
|
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chat_message_content_parts(
|
||||||
|
role: str,
|
||||||
|
parts: Iterable[ChatCompletionContentPartParam],
|
||||||
|
mm_tracker: BaseMultiModalItemTracker,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
|
texts: List[str] = []
|
||||||
|
|
||||||
|
mm_parser = mm_tracker.create_parser()
|
||||||
|
keep_multimodal_content = \
|
||||||
|
mm_tracker._model_config.hf_config.model_type in \
|
||||||
|
MODEL_KEEP_MULTI_MODAL_CONTENT
|
||||||
|
|
||||||
|
has_image = False
|
||||||
|
for part in parts:
|
||||||
|
part_type = part["type"]
|
||||||
|
if part_type == "text":
|
||||||
|
text = _TextParser(part)["text"]
|
||||||
|
texts.append(text)
|
||||||
|
elif part_type == "image_url":
|
||||||
|
image_url = _ImageParser(part)["image_url"]
|
||||||
|
|
||||||
|
if image_url.get("detail", "auto") != "auto":
|
||||||
|
logger.warning(
|
||||||
|
"'image_url.detail' is currently not supported and "
|
||||||
|
"will be ignored.")
|
||||||
|
|
||||||
|
mm_parser.parse_image(image_url["url"])
|
||||||
|
has_image = True
|
||||||
|
elif part_type == "audio_url":
|
||||||
|
audio_url = _AudioParser(part)["audio_url"]
|
||||||
|
|
||||||
|
mm_parser.parse_audio(audio_url["url"])
|
||||||
|
elif part_type == "refusal":
|
||||||
|
text = _RefusalParser(part)["refusal"]
|
||||||
|
texts.append(text)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||||
|
|
||||||
|
text_prompt = "\n".join(texts)
|
||||||
|
if keep_multimodal_content:
|
||||||
|
text_prompt = "\n".join(texts)
|
||||||
|
role_content = [{'type': 'text', 'text': text_prompt}]
|
||||||
|
|
||||||
|
if has_image:
|
||||||
|
role_content = [{'type': 'image'}] + role_content
|
||||||
|
return [ConversationMessage(role=role,
|
||||||
|
content=role_content)] # type: ignore
|
||||||
|
else:
|
||||||
|
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||||
|
if mm_placeholder_counts:
|
||||||
|
text_prompt = _get_full_multimodal_text_prompt(
|
||||||
|
mm_placeholder_counts, text_prompt)
|
||||||
|
return [ConversationMessage(role=role, content=text_prompt)]
|
||||||
|
|
||||||
|
|
||||||
|
# No need to validate using Pydantic again
|
||||||
|
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
|
||||||
|
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chat_message_content(
|
||||||
|
message: ChatCompletionMessageParam,
|
||||||
|
mm_tracker: BaseMultiModalItemTracker,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
|
role = message["role"]
|
||||||
|
content = message.get("content")
|
||||||
|
|
||||||
|
if content is None:
|
||||||
|
content = []
|
||||||
|
elif isinstance(content, str):
|
||||||
|
content = [
|
||||||
|
ChatCompletionContentPartTextParam(type="text", text=content)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = _parse_chat_message_content_parts(
|
||||||
|
role,
|
||||||
|
content, # type: ignore
|
||||||
|
mm_tracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
for result_msg in result:
|
||||||
|
if role == 'assistant':
|
||||||
|
parsed_msg = _AssistantParser(message)
|
||||||
|
|
||||||
|
if "tool_calls" in parsed_msg:
|
||||||
|
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
||||||
|
|
||||||
|
# Pass reasoning content as a dedicated field so the chat template
|
||||||
|
# can render it natively (Qwen3: message.reasoning_content branch).
|
||||||
|
# Accept both "reasoning" (new vllm) and "reasoning_content" (ours).
|
||||||
|
reasoning = (message.get("reasoning") # type: ignore[arg-type]
|
||||||
|
or message.get("reasoning_content")) # type: ignore[arg-type]
|
||||||
|
if reasoning and isinstance(reasoning, str):
|
||||||
|
result_msg["reasoning_content"] = reasoning
|
||||||
|
|
||||||
|
elif role == "tool":
|
||||||
|
parsed_msg = _ToolParser(message)
|
||||||
|
if "tool_call_id" in parsed_msg:
|
||||||
|
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
|
||||||
|
|
||||||
|
if "name" in message and isinstance(message["name"], str):
|
||||||
|
result_msg["name"] = message["name"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||||
|
# per the Transformers docs & maintainers, tool call arguments in
|
||||||
|
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||||
|
# this is how tool-use chat templates will expect them moving forwards
|
||||||
|
# so, for messages that have tool_calls, parse the string (which we get
|
||||||
|
# from openAI format) to dict
|
||||||
|
for message in messages:
|
||||||
|
if (message["role"] == "assistant" and "tool_calls" in message
|
||||||
|
and isinstance(message["tool_calls"], list)):
|
||||||
|
|
||||||
|
for item in message["tool_calls"]:
|
||||||
|
item["function"]["arguments"] = json.loads(
|
||||||
|
item["function"]["arguments"])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_messages(
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
model_config: ModelConfig,
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||||
|
conversation: List[ConversationMessage] = []
|
||||||
|
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||||
|
|
||||||
|
conversation.extend(sub_messages)
|
||||||
|
|
||||||
|
_postprocess_messages(conversation)
|
||||||
|
|
||||||
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_messages_futures(
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
model_config: ModelConfig,
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||||
|
conversation: List[ConversationMessage] = []
|
||||||
|
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||||
|
|
||||||
|
conversation.extend(sub_messages)
|
||||||
|
|
||||||
|
_postprocess_messages(conversation)
|
||||||
|
|
||||||
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hf_chat_template(
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
|
conversation: List[ConversationMessage],
|
||||||
|
chat_template: Optional[str],
|
||||||
|
*,
|
||||||
|
tokenize: bool = False, # Different from HF's default
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
if chat_template is None and tokenizer.chat_template is None:
|
||||||
|
raise ValueError(
|
||||||
|
"As of transformers v4.44, default chat template is no longer "
|
||||||
|
"allowed, so you must provide a chat template if the tokenizer "
|
||||||
|
"does not define one.")
|
||||||
|
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
conversation=conversation, # type: ignore[arg-type]
|
||||||
|
chat_template=chat_template,
|
||||||
|
tokenize=tokenize,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_mistral_chat_template(
|
||||||
|
tokenizer: MistralTokenizer,
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[int]:
|
||||||
|
if chat_template is not None:
|
||||||
|
logger.warning(
|
||||||
|
"'chat_template' cannot be overridden for mistral tokenizer.")
|
||||||
|
if "add_generation_prompt" in kwargs:
|
||||||
|
logger.warning(
|
||||||
|
"'add_generation_prompt' is not supported for mistral tokenizer, "
|
||||||
|
"so it will be ignored.")
|
||||||
|
if "continue_final_message" in kwargs:
|
||||||
|
logger.warning(
|
||||||
|
"'continue_final_message' is not supported for mistral tokenizer, "
|
||||||
|
"so it will be ignored.")
|
||||||
|
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages=messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
261
qwen3_6_scripts/cli_args.py
Normal file
261
qwen3_6_scripts/cli_args.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
"""
|
||||||
|
This file contains the command line arguments for the vLLM's
|
||||||
|
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||||
|
purposes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import ssl
|
||||||
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||||
|
from vllm.entrypoints.chat_utils import validate_chat_template
|
||||||
|
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||||
|
PromptAdapterPath)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAParserAction(argparse.Action):
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
parser: argparse.ArgumentParser,
|
||||||
|
namespace: argparse.Namespace,
|
||||||
|
values: Optional[Union[str, Sequence[str]]],
|
||||||
|
option_string: Optional[str] = None,
|
||||||
|
):
|
||||||
|
if values is None:
|
||||||
|
values = []
|
||||||
|
if isinstance(values, str):
|
||||||
|
raise TypeError("Expected values to be a list")
|
||||||
|
|
||||||
|
lora_list: List[LoRAModulePath] = []
|
||||||
|
for item in values:
|
||||||
|
if item in [None, '']: # Skip if item is None or empty string
|
||||||
|
continue
|
||||||
|
if '=' in item and ',' not in item: # Old format: name=path
|
||||||
|
name, path = item.split('=')
|
||||||
|
lora_list.append(LoRAModulePath(name, path))
|
||||||
|
else: # Assume JSON format
|
||||||
|
try:
|
||||||
|
lora_dict = json.loads(item)
|
||||||
|
lora = LoRAModulePath(**lora_dict)
|
||||||
|
lora_list.append(lora)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parser.error(
|
||||||
|
f"Invalid JSON format for --lora-modules: {item}")
|
||||||
|
except TypeError as e:
|
||||||
|
parser.error(
|
||||||
|
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||||
|
)
|
||||||
|
setattr(namespace, self.dest, lora_list)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptAdapterParserAction(argparse.Action):
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
parser: argparse.ArgumentParser,
|
||||||
|
namespace: argparse.Namespace,
|
||||||
|
values: Optional[Union[str, Sequence[str]]],
|
||||||
|
option_string: Optional[str] = None,
|
||||||
|
):
|
||||||
|
if values is None:
|
||||||
|
values = []
|
||||||
|
if isinstance(values, str):
|
||||||
|
raise TypeError("Expected values to be a list")
|
||||||
|
|
||||||
|
adapter_list: List[PromptAdapterPath] = []
|
||||||
|
for item in values:
|
||||||
|
name, path = item.split('=')
|
||||||
|
adapter_list.append(PromptAdapterPath(name, path))
|
||||||
|
setattr(namespace, self.dest, adapter_list)
|
||||||
|
|
||||||
|
|
||||||
|
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||||
|
parser.add_argument("--host",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="host name")
|
||||||
|
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||||
|
parser.add_argument(
|
||||||
|
"--uvicorn-log-level",
|
||||||
|
type=str,
|
||||||
|
default="info",
|
||||||
|
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
|
||||||
|
help="log level for uvicorn")
|
||||||
|
parser.add_argument("--allow-credentials",
|
||||||
|
action="store_true",
|
||||||
|
help="allow credentials")
|
||||||
|
parser.add_argument("--allowed-origins",
|
||||||
|
type=json.loads,
|
||||||
|
default=["*"],
|
||||||
|
help="allowed origins")
|
||||||
|
parser.add_argument("--allowed-methods",
|
||||||
|
type=json.loads,
|
||||||
|
default=["*"],
|
||||||
|
help="allowed methods")
|
||||||
|
parser.add_argument("--allowed-headers",
|
||||||
|
type=json.loads,
|
||||||
|
default=["*"],
|
||||||
|
help="allowed headers")
|
||||||
|
parser.add_argument("--api-key",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="If provided, the server will require this key "
|
||||||
|
"to be presented in the header.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-modules",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
nargs='+',
|
||||||
|
action=LoRAParserAction,
|
||||||
|
help="LoRA module configurations in either 'name=path' format"
|
||||||
|
"or JSON format. "
|
||||||
|
"Example (old format): 'name=path' "
|
||||||
|
"Example (new format): "
|
||||||
|
"'{\"name\": \"name\", \"local_path\": \"path\", "
|
||||||
|
"\"base_model_name\": \"id\"}'")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-adapters",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
nargs='+',
|
||||||
|
action=PromptAdapterParserAction,
|
||||||
|
help="Prompt adapter configurations in the format name=path. "
|
||||||
|
"Multiple adapters can be specified.")
|
||||||
|
parser.add_argument("--chat-template",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="The file path to the chat template, "
|
||||||
|
"or the template in single-line form "
|
||||||
|
"for the specified model")
|
||||||
|
parser.add_argument("--response-role",
|
||||||
|
type=nullable_str,
|
||||||
|
default="assistant",
|
||||||
|
help="The role name to return if "
|
||||||
|
"`request.add_generation_prompt=true`.")
|
||||||
|
parser.add_argument("--ssl-keyfile",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="The file path to the SSL key file")
|
||||||
|
parser.add_argument("--ssl-certfile",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="The file path to the SSL cert file")
|
||||||
|
parser.add_argument("--ssl-ca-certs",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="The CA certificates file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ssl-cert-reqs",
|
||||||
|
type=int,
|
||||||
|
default=int(ssl.CERT_NONE),
|
||||||
|
help="Whether client certificate is required (see stdlib ssl module's)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--root-path",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="FastAPI root_path when app is behind a path based routing proxy")
|
||||||
|
parser.add_argument(
|
||||||
|
"--middleware",
|
||||||
|
type=nullable_str,
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Additional ASGI middleware to apply to the app. "
|
||||||
|
"We accept multiple --middleware arguments. "
|
||||||
|
"The value should be an import path. "
|
||||||
|
"If a function is provided, vLLM will add it to the server "
|
||||||
|
"using @app.middleware('http'). "
|
||||||
|
"If a class is provided, vLLM will add it to the server "
|
||||||
|
"using app.add_middleware(). ")
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-tokens-as-token-ids",
|
||||||
|
action="store_true",
|
||||||
|
help="When --max-logprobs is specified, represents single tokens as "
|
||||||
|
"strings of the form 'token_id:{token_id}' so that tokens that "
|
||||||
|
"are not JSON-encodable can be identified.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-frontend-multiprocessing",
|
||||||
|
action="store_true",
|
||||||
|
help="If specified, will run the OpenAI frontend server in the same "
|
||||||
|
"process as the model serving engine.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-auto-tool-choice",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help=
|
||||||
|
"Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||||
|
"to specify which parser to use")
|
||||||
|
|
||||||
|
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||||
|
parser.add_argument(
|
||||||
|
"--tool-call-parser",
|
||||||
|
type=str,
|
||||||
|
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||||
|
"--tool-parser-plugin",
|
||||||
|
default=None,
|
||||||
|
help=
|
||||||
|
"Select the tool call parser depending on the model that you're using."
|
||||||
|
" This is used to parse the model-generated tool call into OpenAI API "
|
||||||
|
"format. Required for --enable-auto-tool-choice.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tool-parser-plugin",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help=
|
||||||
|
"Special the tool parser plugin write to parse the model-generated tool"
|
||||||
|
" into OpenAI API format, the name register in this plugin can be used "
|
||||||
|
"in --tool-call-parser.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--reasoning-parser",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=
|
||||||
|
"Select the reasoning parser to split <think>...</think> content into "
|
||||||
|
"reasoning_content vs content in the response. "
|
||||||
|
"Supported: qwen3")
|
||||||
|
|
||||||
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument('--max-log-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Max number of prompt characters or prompt '
|
||||||
|
'ID numbers being printed in log.'
|
||||||
|
'\n\nDefault: Unlimited')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-fastapi-docs",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||||
|
"""Quick checks for model serve args that raise prior to loading."""
|
||||||
|
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure that the chat template is valid; raises if it likely isn't
|
||||||
|
validate_chat_template(args.chat_template)
|
||||||
|
|
||||||
|
# Enable auto tool needs a tool call parser to be valid
|
||||||
|
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||||
|
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||||
|
"--tool-call-parser")
|
||||||
|
|
||||||
|
|
||||||
|
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||||
|
parser_for_docs = FlexibleArgumentParser(
|
||||||
|
prog="-m vllm.entrypoints.openai.api_server")
|
||||||
|
return make_arg_parser(parser_for_docs)
|
||||||
@@ -309,28 +309,23 @@ class PagedAttention:
|
|||||||
seq_lens_tensor: torch.Tensor,
|
seq_lens_tensor: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pure-PyTorch prefix-attention with query-chunking (no Triton).
|
"""Pure-PyTorch prefix-attention with K-tiling (Flash-Attention online softmax).
|
||||||
|
|
||||||
For each sequence, gathers the context KV from the paged KV cache,
|
Memory complexity: O(q_len), independent of kv_len.
|
||||||
concatenates with the current-chunk K/V, then computes scaled-dot-
|
With chunked prefill (q_len ≤ max_num_batched_tokens = 4096) peak
|
||||||
product attention with a causal mask.
|
per layer ≈ 96 MB regardless of context length.
|
||||||
|
|
||||||
Memory optimisation — GQA-aware Q-tiling
|
Algorithm: Flash Attention online softmax.
|
||||||
-----------------------------------------
|
Q is reshaped once to [kv_h, gqa, q_len, d] (24 MB) and held for all
|
||||||
Two complementary tricks keep peak activation memory well below 1 GB
|
K-tiles. For each tile a running (m, l, o) accumulator is updated —
|
||||||
even for 100K context on TP=4 (kv_h=1, q_h=6):
|
the [q_len × kv_len] attention matrix is NEVER materialised in full.
|
||||||
|
|
||||||
1. No GQA pre-expansion: K/V are kept at native [kv_h, kv_len, d]
|
Tile budget (kv_h=1, gqa=6, q_len=4096, tile=256 tokens):
|
||||||
resolution and GQA grouping is handled via 4D reshape+broadcast
|
q_seq [1, 6, 4096, 256] fp32 24 MB (held all tiles)
|
||||||
inside the matmul. With kv_h=1 and kv_len=100K this saves ~6×
|
o_acc same shape 24 MB (held all tiles)
|
||||||
vs the old expand-then-float32 approach:
|
s same shape 24 MB (per tile, freed before exp_s)
|
||||||
Old: [6, 100K, 256] fp32 = 586 MB each for K and V
|
exp_s same shape 24 MB (per tile, brief overlap with s)
|
||||||
New: [1, 100K, 256] fp32 = 98 MB each for K and V
|
Peak ≈ 96 MB (s and exp_s briefly coexist during update).
|
||||||
|
|
||||||
2. Q-tiling (_ATTN_Q_CHUNK=64): attn_w [kv_h, gqa, Q, kv_len] fp32
|
|
||||||
is bounded to ~148 MB at 100K instead of growing with q_len.
|
|
||||||
|
|
||||||
Combined peak per layer (100K): ~352 MB vs ~1200 MB previously.
|
|
||||||
|
|
||||||
Shapes
|
Shapes
|
||||||
------
|
------
|
||||||
@@ -344,29 +339,24 @@ class PagedAttention:
|
|||||||
seq_lens_tensor: [batch_size] total length (context + query)
|
seq_lens_tensor: [batch_size] total length (context + query)
|
||||||
context_lens : [batch_size] tokens already in KV cache
|
context_lens : [batch_size] tokens already in KV cache
|
||||||
"""
|
"""
|
||||||
# Memory-efficient query-chunked attention.
|
|
||||||
# Key optimisation: do NOT expand KV heads for GQA before materialising
|
|
||||||
# k_t / v_t. With kv_h=1 (Qwen3.6 TP=4), keeping K/V at native kv_h
|
|
||||||
# resolution saves ~6× memory vs expanding to q_h first:
|
|
||||||
# Old path (expand then float32): [6, 100K, 256] fp32 = 586 MB
|
|
||||||
# New path (keep kv_h, float32): [1, 100K, 256] fp32 = 98 MB
|
|
||||||
# GQA grouping is handled lazily inside the Q-tile matmul via 4D
|
|
||||||
# reshaping, so no extra tensors are created.
|
|
||||||
try:
|
try:
|
||||||
_ATTN_Q_CHUNK = 64 # [kv_h, gqa, Q_CHUNK, kv_len] fp32 ≤ 150 MB
|
# Paged-block tiles for context phase.
|
||||||
|
# tile_sz = _BLOCKS_PER_TILE × block_size (e.g. 16×16 = 256 tokens).
|
||||||
|
# Score tensor [kv_h, gqa, q_len, tile_sz] fp32 = 24 MB per tile.
|
||||||
|
# Same tile size reused for the current-chunk phase.
|
||||||
|
_BLOCKS_PER_TILE = 32
|
||||||
|
|
||||||
batch_size = seq_lens_tensor.shape[0]
|
batch_size = seq_lens_tensor.shape[0]
|
||||||
num_q_heads = query.shape[1]
|
num_q_heads = query.shape[1]
|
||||||
num_kv_heads = key_cache.shape[1]
|
num_kv_heads = key_cache.shape[1]
|
||||||
head_dim = query.shape[2]
|
head_dim = query.shape[2]
|
||||||
gqa_ratio = num_q_heads // num_kv_heads
|
gqa_ratio = num_q_heads // num_kv_heads
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
# value_cache: [num_blocks, num_kv_heads, head_dim, block_size]
|
tile_sz = _BLOCKS_PER_TILE * block_size
|
||||||
block_size = value_cache.shape[3]
|
scale = head_dim ** -0.5
|
||||||
|
orig_dtype = query.dtype
|
||||||
scale = 1.0 / (head_dim ** 0.5)
|
output = torch.empty_like(query)
|
||||||
output = torch.empty_like(query)
|
dev = query.device
|
||||||
orig_dtype = query.dtype
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
ctx_len = int(context_lens[i].item())
|
ctx_len = int(context_lens[i].item())
|
||||||
@@ -374,96 +364,147 @@ class PagedAttention:
|
|||||||
q_end = int(query_start_loc[i + 1].item())
|
q_end = int(query_start_loc[i + 1].item())
|
||||||
q_len = q_end - q_start
|
q_len = q_end - q_start
|
||||||
|
|
||||||
q_i = query[q_start:q_end] # [q_len, num_q_heads, head_dim]
|
q_i = query[q_start:q_end] # [q_len, q_h, d]
|
||||||
k_i = key [q_start:q_end] # [q_len, num_kv_heads, head_dim]
|
k_i = key [q_start:q_end] # [q_len, kv_h, d]
|
||||||
v_i = value[q_start:q_end]
|
v_i = value[q_start:q_end]
|
||||||
|
|
||||||
# --- Build full K/V (context from cache + current chunk) ----
|
# Q reshaped and scaled once; held for all K-tiles.
|
||||||
|
# [kv_h, gqa, q_len, d] fp32 — 24 MB for q_len=4096, d=256
|
||||||
|
q_seq = (q_i.permute(1, 0, 2)
|
||||||
|
.float()
|
||||||
|
.view(num_kv_heads, gqa_ratio, q_len, head_dim)
|
||||||
|
.mul_(scale))
|
||||||
|
|
||||||
|
# Flash-Attention online-softmax accumulators.
|
||||||
|
# m, l : [kv_h, gqa, q_len] fp32 — <0.1 MB
|
||||||
|
# o : [kv_h, gqa, q_len, d] fp32 — 24 MB
|
||||||
|
m = torch.full((num_kv_heads, gqa_ratio, q_len),
|
||||||
|
float('-inf'), dtype=torch.float32, device=dev)
|
||||||
|
l = torch.zeros_like(m)
|
||||||
|
o = torch.zeros((num_kv_heads, gqa_ratio, q_len, head_dim),
|
||||||
|
dtype=torch.float32, device=dev)
|
||||||
|
|
||||||
|
# --------------------------------------------------------------
|
||||||
|
# Phase 1 — context tokens (positions 0 … ctx_len-1).
|
||||||
|
#
|
||||||
|
# Every context key has absolute position < ctx_len; every
|
||||||
|
# query has position ≥ ctx_len. k_pos < q_pos is always True
|
||||||
|
# → no causal mask needed for pure context tiles.
|
||||||
|
# --------------------------------------------------------------
|
||||||
if ctx_len > 0:
|
if ctx_len > 0:
|
||||||
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
|
num_ctx_blocks = (ctx_len + block_size - 1) // block_size
|
||||||
blk_ids = block_tables[i, :num_ctx_blocks]
|
for tile_blk in range(0, num_ctx_blocks, _BLOCKS_PER_TILE):
|
||||||
|
blk_end = min(tile_blk + _BLOCKS_PER_TILE, num_ctx_blocks)
|
||||||
|
blk_ids = block_tables[i, tile_blk:blk_end]
|
||||||
|
|
||||||
# key_cache[blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
# Gather K/V for this tile.
|
||||||
# → permute(0,3,1,2,4) → contiguous → view → [:ctx_len]
|
# key_cache [blk_ids]: [n, kv_h, d//x, blk_sz, x]
|
||||||
k_ctx = (key_cache[blk_ids]
|
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
||||||
.permute(0, 3, 1, 2, 4)
|
k_tile = (key_cache[blk_ids]
|
||||||
.contiguous()
|
.permute(0, 3, 1, 2, 4)
|
||||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
.contiguous()
|
||||||
|
.view(-1, num_kv_heads, head_dim))
|
||||||
|
v_tile = (value_cache[blk_ids]
|
||||||
|
.permute(0, 3, 1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view(-1, num_kv_heads, head_dim))
|
||||||
|
|
||||||
# value_cache[blk_ids]: [n, kv_h, d, blk_sz]
|
# Trim padding in the last block of the tile.
|
||||||
# → permute(0,3,1,2) → contiguous → view → [:ctx_len]
|
valid = (min(blk_end * block_size, ctx_len)
|
||||||
v_ctx = (value_cache[blk_ids]
|
- tile_blk * block_size)
|
||||||
.permute(0, 3, 1, 2)
|
k_tile = k_tile[:valid] # [valid, kv_h, d]
|
||||||
.contiguous()
|
v_tile = v_tile[:valid]
|
||||||
.view(-1, num_kv_heads, head_dim))[:ctx_len]
|
|
||||||
|
|
||||||
k_full = torch.cat([k_ctx, k_i], dim=0) # [kv_len, kv_h, d]
|
# k_t: [kv_h, 1, d, valid] (broadcast over gqa_ratio)
|
||||||
v_full = torch.cat([v_ctx, v_i], dim=0)
|
# v_t: [kv_h, 1, valid, d]
|
||||||
del k_ctx, v_ctx
|
k_t = (k_tile.permute(1, 0, 2)
|
||||||
else:
|
.unsqueeze(1)
|
||||||
k_full = k_i
|
.transpose(-1, -2)
|
||||||
v_full = v_i
|
.float())
|
||||||
|
v_t = (v_tile.permute(1, 0, 2)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.float())
|
||||||
|
del k_tile, v_tile
|
||||||
|
|
||||||
kv_len = k_full.shape[0] # ctx_len + q_len
|
# Scores: [kv_h, gqa, q_len, valid]
|
||||||
|
s = torch.matmul(q_seq, k_t)
|
||||||
|
del k_t
|
||||||
|
# No causal mask: all context keys precede all queries.
|
||||||
|
|
||||||
# Transpose to [kv_h, kv_len, d], keep original dtype (fp16/bf16).
|
# Online softmax update — Flash-Attention Algorithm 1.
|
||||||
# Do NOT cast to fp32 here — k/v stay in fp16 to halve memory.
|
# exp_s = s - new_max (in-place exp after del s)
|
||||||
# attn_w is computed in fp32 (q cast to fp32 before matmul, then
|
m_blk = s.amax(dim=-1)
|
||||||
# k cast inline) so softmax precision is unaffected.
|
m_new = torch.maximum(m, m_blk)
|
||||||
# Do NOT expand GQA heads here either — gqa_ratio x memory savings.
|
exp_s = s - m_new.unsqueeze(-1)
|
||||||
k_t = k_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
del s
|
||||||
del k_full
|
exp_s.exp_()
|
||||||
v_t = v_full.permute(1, 0, 2).contiguous() # [kv_h, kv_len, d] fp16
|
corr = torch.exp(m - m_new)
|
||||||
del v_full
|
m.copy_(m_new)
|
||||||
|
del m_blk, m_new
|
||||||
|
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||||
|
o.mul_(corr.unsqueeze(-1)).add_(
|
||||||
|
torch.matmul(exp_s, v_t))
|
||||||
|
del exp_s, v_t, corr
|
||||||
|
|
||||||
# k_pos used for causal mask: shape [kv_len]
|
# --------------------------------------------------------------
|
||||||
k_pos = torch.arange(kv_len, device=query.device)
|
# Phase 2 — current-chunk tokens (positions ctx_len … ctx_len+q_len-1).
|
||||||
|
#
|
||||||
|
# Causal mask: query at relative position j sees key at relative
|
||||||
|
# position k only when k ≤ j. Tiles of tile_sz tokens each.
|
||||||
|
# --------------------------------------------------------------
|
||||||
|
for kc_start in range(0, q_len, tile_sz):
|
||||||
|
kc_end = min(kc_start + tile_sz, q_len)
|
||||||
|
kc_len = kc_end - kc_start
|
||||||
|
|
||||||
# --- Query-chunked attention with lazy GQA grouping ----------
|
k_blk = k_i[kc_start:kc_end] # [kc_len, kv_h, d]
|
||||||
# q_i reshaped to [kv_h, gqa_ratio, qc, d] so matmul with
|
v_blk = v_i[kc_start:kc_end]
|
||||||
# k_t [kv_h, kv_len, d] (broadcast over gqa_ratio dim) gives
|
|
||||||
# attn_w [kv_h, gqa_ratio, qc, kv_len] without extra K copies.
|
|
||||||
for qc_start in range(0, q_len, _ATTN_Q_CHUNK):
|
|
||||||
qc_end = min(qc_start + _ATTN_Q_CHUNK, q_len)
|
|
||||||
qc = qc_end - qc_start
|
|
||||||
|
|
||||||
# [kv_h, gqa_ratio, qc, d]
|
k_t = (k_blk.permute(1, 0, 2)
|
||||||
q_t_chunk = (q_i[qc_start:qc_end]
|
.unsqueeze(1)
|
||||||
.permute(1, 0, 2) # [q_h, qc, d]
|
.transpose(-1, -2)
|
||||||
.float()
|
.float()) # [kv_h, 1, d, kc_len]
|
||||||
.view(num_kv_heads, gqa_ratio, qc, head_dim))
|
v_t = (v_blk.permute(1, 0, 2)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.float()) # [kv_h, 1, kc_len, d]
|
||||||
|
|
||||||
# [kv_h, gqa_ratio, qc, kv_len]
|
s = torch.matmul(q_seq, k_t) # [kv_h, gqa, q_len, kc_len]
|
||||||
# k_t unsqueezed to [kv_h, 1, kv_len, d] broadcasts over gqa_ratio.
|
del k_t
|
||||||
# Cast k slice to fp32 inline; the temporary is freed after matmul.
|
|
||||||
attn_w = torch.matmul(q_t_chunk * scale,
|
|
||||||
k_t.unsqueeze(1).transpose(-1, -2).float())
|
|
||||||
|
|
||||||
# Causal mask for this sub-chunk:
|
# Causal mask: key at (kc_start+k) must not exceed query j.
|
||||||
# query absolute position = ctx_len + qc_start..qc_end-1
|
k_rel = torch.arange(kc_start, kc_end, device=dev)
|
||||||
qc_q_pos = torch.arange(qc_start, qc_end,
|
q_rel = torch.arange(q_len, device=dev)
|
||||||
device=query.device)
|
mask = k_rel.unsqueeze(0) > q_rel.unsqueeze(1) # [q_len, kc_len]
|
||||||
# mask[j, k] = True → future key, block it
|
s.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
||||||
mask = k_pos.unsqueeze(0) > (ctx_len + qc_q_pos.unsqueeze(1))
|
del mask, k_rel, q_rel
|
||||||
attn_w.masked_fill_(
|
|
||||||
mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
|
||||||
|
|
||||||
# In-place numerically stable softmax — avoids allocating a
|
# Online softmax update (identical to context phase).
|
||||||
# new 150 MB tensor (same size as attn_w) that torch.softmax
|
m_blk = s.amax(dim=-1)
|
||||||
# would create, which exhausts the fragmented GPU pool.
|
m_new = torch.maximum(m, m_blk)
|
||||||
attn_w -= attn_w.amax(dim=-1, keepdim=True)
|
exp_s = s - m_new.unsqueeze(-1)
|
||||||
attn_w.exp_()
|
del s
|
||||||
attn_w /= attn_w.sum(dim=-1, keepdim=True)
|
exp_s.exp_()
|
||||||
# [kv_h, gqa_ratio, qc, d]; v_t cast to fp32 inline
|
corr = torch.exp(m - m_new)
|
||||||
out_c = torch.matmul(attn_w,
|
m.copy_(m_new)
|
||||||
v_t.unsqueeze(1).float())
|
del m_blk, m_new
|
||||||
# reshape to [q_h, qc, d] then [qc, q_h, d]
|
l.mul_(corr).add_(exp_s.sum(dim=-1))
|
||||||
out_c = out_c.view(num_q_heads, qc, head_dim)
|
o.mul_(corr.unsqueeze(-1)).add_(
|
||||||
|
torch.matmul(exp_s, v_t))
|
||||||
|
del exp_s, v_t, corr
|
||||||
|
|
||||||
|
# --------------------------------------------------------------
|
||||||
|
# Finalize: normalize running output by normalization factor.
|
||||||
|
# o: [kv_h, gqa, q_len, d] → [q_len, q_h, d]
|
||||||
|
# --------------------------------------------------------------
|
||||||
|
o.div_(l.unsqueeze(-1))
|
||||||
|
output[q_start:q_end] = (
|
||||||
|
o.view(num_q_heads, q_len, head_dim)
|
||||||
|
.permute(1, 0, 2)
|
||||||
|
.to(orig_dtype)
|
||||||
|
)
|
||||||
|
|
||||||
output[q_start + qc_start : q_start + qc_end] = (
|
|
||||||
out_c.to(orig_dtype).permute(1, 0, 2))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[paged_attn ERROR] {type(e).__name__}: {e}", file=sys.stderr, flush=True)
|
print(f"[paged_attn ERROR] {type(e).__name__}: {e}",
|
||||||
|
file=sys.stderr, flush=True)
|
||||||
traceback.print_exc(file=sys.stderr)
|
traceback.print_exc(file=sys.stderr)
|
||||||
raise
|
raise
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -10,18 +10,11 @@
|
|||||||
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
|
# GPU hang on BI-V100 because the Triton CUDA PTX kernels are incompatible.
|
||||||
#
|
#
|
||||||
# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs
|
# Important Note: Qwen3.6-27B must apply TP=4,PP=2 combination in order to deploy using 8 GPUs
|
||||||
#
|
|
||||||
# Recommended server start command for TP=4, context length: 50K, no chunked prefill mechanism:
|
|
||||||
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
|
|
||||||
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
|
|
||||||
# --max-model-len 50000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.90 \
|
|
||||||
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
|
||||||
# --max-num-batched-tokens 50000
|
|
||||||
|
|
||||||
# Recommended server start command for TP=4 support 100K, need chunked prefill
|
# Recommended server start command for TP=4 support 100K, need chunked prefill
|
||||||
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
|
# CUDA_VISIBLE_DEVICES="4,5,6,7" VLLM_ENGINE_ITERATION_TIMEOUT_S=3600 python3 -m vllm.entrypoints.openai.api_server \
|
||||||
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
|
# --model /workspace/models/Qwen3.6-27B --port 1111 --served-model-name llm \
|
||||||
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 8 --gpu-memory-utilization 0.95 \
|
# --max-model-len 100000 --enforce-eager --trust-remote-code -tp 4 --gpu-memory-utilization 0.95 \
|
||||||
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
# --max-num-seqs 1 --disable-log-requests --disable-frontend-multiprocessing \
|
||||||
# --max-num-batched-tokens 4096 --enable-chunked-prefill
|
# --max-num-batched-tokens 4096 --enable-chunked-prefill
|
||||||
|
|
||||||
@@ -29,13 +22,14 @@
|
|||||||
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
|
# The Triton context_attention_fwd kernel hangs BI-V100 GPUs permanently
|
||||||
# (standard Triton 2.3.1 PTX is not supported by the corex runtime either).
|
# (standard Triton 2.3.1 PTX is not supported by the corex runtime either).
|
||||||
# Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which
|
# Our paged_attn.py bypasses it entirely via _forward_prefix_pytorch, which
|
||||||
# also implements query-chunking (_ATTN_Q_CHUNK=256) to keep peak attention
|
# utilizes K-tiling techniques, and also have _forward_decode_pytorch to bypass kernel
|
||||||
# memory at O(256 × kv_len) instead of O(q_len × kv_len).
|
# when context length is high
|
||||||
cp ./paged_attn.py /usr/local/corex/lib64/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
cp ./paged_attn.py /usr/local/corex/lib/python3/dist-packages/vllm/attention/ops/paged_attn.py
|
||||||
|
|
||||||
# --- transformers: Qwen3_5 tokenizer / model files --------------------------
|
# --- transformers: Qwen3_5 tokenizer / model files --------------------------
|
||||||
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
pip install transformers==4.55.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
cp -r ./qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||||
|
cp -r ./qwen3_5_moe /usr/local/lib/python3.10/site-packages/transformers/models/
|
||||||
python3 ./patch_transformers_qwen3_5.py
|
python3 ./patch_transformers_qwen3_5.py
|
||||||
|
|
||||||
# --- vllm model: Qwen3.6-27B (Qwen3_5 arch) --------------------------------
|
# --- vllm model: Qwen3.6-27B (Qwen3_5 arch) --------------------------------
|
||||||
@@ -43,10 +37,36 @@ cp ./mamba_cache.py /usr/local/corex/lib/python3/dist-packages/vllm/model_execut
|
|||||||
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/qwen3_5.py
|
cp ./qwen3_5.py /usr/local/corex/lib/python3/dist-packages/vllm/model_executor/models/qwen3_5.py
|
||||||
python3 ./patch_vllm_qwen3_5.py
|
python3 ./patch_vllm_qwen3_5.py
|
||||||
|
|
||||||
|
# --- sequence.py: fix completion_tokens inflation under chunked prefill ------
|
||||||
|
# Bug: get_output_token_ids_to_return(delta=True) with num_new_tokens=0
|
||||||
|
# returns _cached_all_token_ids[-0:] == [0:] (the ENTIRE prompt+output list).
|
||||||
|
# Each prefill chunk step adds prompt_len to previous_num_tokens, so a 10K
|
||||||
|
# prompt processed in 3 chunks inflates completion_tokens by ~30K.
|
||||||
|
cp ./sequence.py /usr/local/corex/lib/python3/dist-packages/vllm/sequence.py
|
||||||
|
|
||||||
# --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------
|
# --- xformers: bypass cudnnFlashAttnForward (head_dim=256 > 128 limit) ------
|
||||||
# Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py.
|
# Injects _run_sdpa_fallback (pure matmul+softmax) into xformers.py.
|
||||||
# Required because head_dim=256 > 128 and ixformer flash attention either
|
# Required because head_dim=256 > 128 and ixformer flash attention either
|
||||||
# crashes (is_causal=True) or produces wrong output (attn_mask path).
|
# crashes (is_causal=True) or produces wrong output (attn_mask path).
|
||||||
# The fallback uses query_start_loc to derive actual query lengths, so it
|
# The fallback uses query_start_loc to derive actual query lengths, so it
|
||||||
# works correctly during profiling runs with chunked-prefill-style batches.
|
# works correctly during profiling runs with chunked-prefill-style batches.
|
||||||
|
# also bypasses auto chunked prefill on
|
||||||
python3 ./patch_xformers_sdpa_seq.py
|
python3 ./patch_xformers_sdpa_seq.py
|
||||||
|
|
||||||
|
# --- tool parser: Qwen3 XML tool call format ---------------------------------
|
||||||
|
# Registers "qwen3_coder" parser for Qwen3.6 XML-style tool calls:
|
||||||
|
# <tool_call><function=name><parameter=key>\nvalue\n</parameter></function></tool_call>
|
||||||
|
# Use at server start: --tool-call-parser qwen3_coder --enable-auto-tool-choice
|
||||||
|
cp ./qwen3coder_tool_parser.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
|
||||||
|
python3 ./patch_vllm_tool_parser.py
|
||||||
|
|
||||||
|
# --- reasoning parser: Qwen3 <think>...</think> split ------------------------
|
||||||
|
# Adds --reasoning-parser qwen3 support.
|
||||||
|
# Routes thinking tokens to reasoning_content, rest to content in the delta.
|
||||||
|
# Works together with --tool-call-parser qwen3_coder (think → tool call flow).
|
||||||
|
cp -r ./reasoning /usr/local/corex/lib/python3/dist-packages/vllm/
|
||||||
|
cp ./protocol.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/protocol.py
|
||||||
|
cp ./cli_args.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/cli_args.py
|
||||||
|
cp ./serving_chat.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/serving_chat.py
|
||||||
|
cp ./api_server.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/api_server.py
|
||||||
|
cp ./chat_utils.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/chat_utils.py
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Patches transformers 4.55.3 to register the qwen3_5 model type.
|
Patches transformers 4.55.3 to register qwen3_5 and qwen3_5_moe model types.
|
||||||
|
|
||||||
Deploy steps on the remote machine:
|
Deploy steps on the remote machine:
|
||||||
1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5
|
1. cp -r modified_scripts/qwen3_5 /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5
|
||||||
2. python3 modified_scripts/patch_transformers_qwen3_5.py
|
2. cp -r modified_scripts/qwen3_5_moe /usr/local/lib/python3.10/site-packages/transformers/models/qwen3_5_moe
|
||||||
|
3. python3 modified_scripts/patch_transformers_qwen3_5.py
|
||||||
|
|
||||||
Target: pip-installed transformers at /usr/local/lib/python3.10/site-packages/transformers/
|
Target: pip-installed transformers at /usr/local/lib/python3.10/site-packages/transformers/
|
||||||
(Not the corex pre-installed path at /usr/local/corex/lib64/python3/dist-packages/)
|
(Not the corex pre-installed path at /usr/local/corex/lib64/python3/dist-packages/)
|
||||||
@@ -40,24 +41,23 @@ def patch_file(path, replacements):
|
|||||||
def main():
|
def main():
|
||||||
print(f"=== Patching {AUTO_CONFIG} ===")
|
print(f"=== Patching {AUTO_CONFIG} ===")
|
||||||
patch_file(AUTO_CONFIG, [
|
patch_file(AUTO_CONFIG, [
|
||||||
# CONFIG_MAPPING_NAMES: insert qwen3_5 right after qwen3
|
# CONFIG_MAPPING_NAMES: insert qwen3_5 + qwen3_5_moe right after qwen3
|
||||||
(
|
(
|
||||||
'("qwen3", "Qwen3Config"),',
|
'("qwen3", "Qwen3Config"),',
|
||||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),',
|
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n ("qwen3_5_moe", "Qwen3_5MoeConfig"),',
|
||||||
),
|
),
|
||||||
# Some versions don't have trailing comma — handle that too
|
|
||||||
(
|
(
|
||||||
'("qwen3", "Qwen3Config")\n',
|
'("qwen3", "Qwen3Config")\n',
|
||||||
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n',
|
'("qwen3", "Qwen3Config"),\n ("qwen3_5", "Qwen3_5Config"),\n ("qwen3_5_moe", "Qwen3_5MoeConfig"),\n',
|
||||||
),
|
),
|
||||||
# MODEL_NAMES_MAPPING (model_type -> human readable name, used by docstring generator)
|
# MODEL_NAMES_MAPPING (model_type -> human readable name)
|
||||||
(
|
(
|
||||||
'("qwen3", "Qwen3"),',
|
'("qwen3", "Qwen3"),',
|
||||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),',
|
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n ("qwen3_5_moe", "Qwen3_5_MoE"),',
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
'("qwen3", "Qwen3")\n',
|
'("qwen3", "Qwen3")\n',
|
||||||
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n',
|
'("qwen3", "Qwen3"),\n ("qwen3_5", "Qwen3_5"),\n ("qwen3_5_moe", "Qwen3_5_MoE"),\n',
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ def main():
|
|||||||
patch_file(MODELS_INIT, [
|
patch_file(MODELS_INIT, [
|
||||||
(
|
(
|
||||||
"from .qwen3 import *\n",
|
"from .qwen3 import *\n",
|
||||||
"from .qwen3 import *\n from .qwen3_5 import *\n",
|
"from .qwen3 import *\n from .qwen3_5 import *\n from .qwen3_5_moe import *\n",
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -74,19 +74,39 @@ def main():
|
|||||||
try:
|
try:
|
||||||
import importlib.util, types
|
import importlib.util, types
|
||||||
|
|
||||||
# Quick smoke-test: import the config class directly
|
def _load_config_mod(module_name, file_path):
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
"configuration_qwen3_5",
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
mod.__package__ = ".".join(module_name.split(".")[:-1])
|
||||||
|
pkg = sys.modules.setdefault("transformers", types.ModuleType("transformers"))
|
||||||
|
pkg.__path__ = [TRANSFORMERS_ROOT]
|
||||||
|
cu = sys.modules.setdefault(
|
||||||
|
"transformers.configuration_utils", types.ModuleType("transformers.configuration_utils"))
|
||||||
|
class _PC:
|
||||||
|
def __init__(self, **kwargs): pass
|
||||||
|
cu.PretrainedConfig = _PC
|
||||||
|
for sub in ("transformers.models", f"transformers.models.{module_name.split('.')[-2]}"):
|
||||||
|
m = sys.modules.setdefault(sub, types.ModuleType(sub))
|
||||||
|
m.__path__ = [TRANSFORMERS_ROOT]
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
mod27 = _load_config_mod(
|
||||||
|
"transformers.models.qwen3_5.configuration_qwen3_5",
|
||||||
f"{TRANSFORMERS_ROOT}/models/qwen3_5/configuration_qwen3_5.py",
|
f"{TRANSFORMERS_ROOT}/models/qwen3_5/configuration_qwen3_5.py",
|
||||||
)
|
)
|
||||||
mod = importlib.util.module_from_spec(spec)
|
cfg = mod27.Qwen3_5Config()
|
||||||
# Provide minimal parent package stubs so relative imports resolve
|
print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})")
|
||||||
pkg = types.ModuleType("transformers")
|
|
||||||
pkg.__path__ = [TRANSFORMERS_ROOT]
|
mod35 = _load_config_mod(
|
||||||
sys.modules.setdefault("transformers", pkg)
|
"transformers.models.qwen3_5_moe.configuration_qwen3_5_moe",
|
||||||
spec.loader.exec_module(mod)
|
f"{TRANSFORMERS_ROOT}/models/qwen3_5_moe/configuration_qwen3_5_moe.py",
|
||||||
cfg = mod.Qwen3_5Config()
|
)
|
||||||
print(f" Qwen3_5Config() smoke-test OK (model_type={cfg.model_type})")
|
moe_cfg = mod35.Qwen3_5MoeConfig()
|
||||||
|
print(f" Qwen3_5MoeConfig() smoke-test OK (model_type={moe_cfg.model_type})")
|
||||||
|
t = moe_cfg.text_config
|
||||||
|
print(f" num_experts={t.num_experts}, top_k={t.num_experts_per_tok}, "
|
||||||
|
f"shared={t.shared_expert_intermediate_size}, layers={t.num_hidden_layers}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [warn] smoke-test failed (may be fine at runtime): {e}")
|
print(f" [warn] smoke-test failed (may be fine at runtime): {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ def main():
|
|||||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),',
|
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),',
|
||||||
' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n'
|
' "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),\n'
|
||||||
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),\n'
|
' "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),\n'
|
||||||
' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),',
|
' "Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),\n'
|
||||||
|
' "Qwen3_5MoeForCausalLM": ("qwen3_5", "Qwen3_5MoeForCausalLM"),',
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -61,11 +62,13 @@ def main():
|
|||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod)
|
||||||
cls = mod.Qwen3_5ForCausalLM
|
cls = mod.Qwen3_5ForCausalLM
|
||||||
print(f" Qwen3_5ForCausalLM found: {cls}")
|
print(f" Qwen3_5ForCausalLM found: {cls}")
|
||||||
|
cls_moe = mod.Qwen3_5MoeForCausalLM
|
||||||
|
print(f" Qwen3_5MoeForCausalLM found: {cls_moe}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" [warn] verification failed (may be OK at runtime): {e}")
|
print(f" [warn] verification failed (may be OK at runtime): {e}")
|
||||||
|
|
||||||
print("\nDone. Remember to:")
|
print("\nDone. Remember to:")
|
||||||
print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM']")
|
print(" 1. Set config.json 'architectures': ['Qwen3_5ForCausalLM'] or ['Qwen3_5MoEForCausalLM']")
|
||||||
print(" 2. Run patch_transformers_qwen3_5.py if not already done")
|
print(" 2. Run patch_transformers_qwen3_5.py if not already done")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
79
qwen3_6_scripts/patch_vllm_tool_parser.py
Normal file
79
qwen3_6_scripts/patch_vllm_tool_parser.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
Patches vLLM 0.6.3 to register Qwen3CoderToolParser under the name "qwen3_coder".
|
||||||
|
|
||||||
|
Deploy steps on the remote machine (already called by patch_ops.sh):
|
||||||
|
1. cp qwen3coder_tool_parser.py \
|
||||||
|
/usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
|
||||||
|
2. python3 patch_vllm_tool_parser.py
|
||||||
|
|
||||||
|
Usage after patching:
|
||||||
|
--tool-call-parser qwen3_coder --enable-auto-tool-choice
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
VLLM_ROOT = "/usr/local/corex/lib/python3/dist-packages/vllm"
|
||||||
|
TOOL_PARSERS_DIR = f"{VLLM_ROOT}/entrypoints/openai/tool_parsers"
|
||||||
|
INIT_FILE = f"{TOOL_PARSERS_DIR}/__init__.py"
|
||||||
|
|
||||||
|
|
||||||
|
def patch_file(path, replacements):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
patched = False
|
||||||
|
for old, new in replacements:
|
||||||
|
if new in content:
|
||||||
|
print(f" [skip] already patched: {repr(new[:70])}")
|
||||||
|
continue
|
||||||
|
if old not in content:
|
||||||
|
print(f" [warn] anchor not found: {repr(old[:70])}")
|
||||||
|
continue
|
||||||
|
content = content.replace(old, new, 1)
|
||||||
|
patched = True
|
||||||
|
print(f" [ok] patched: {repr(old[:50])} -> {repr(new[:50])}")
|
||||||
|
|
||||||
|
if patched:
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if not os.path.isdir(TOOL_PARSERS_DIR):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Tool parsers directory not found: {TOOL_PARSERS_DIR}\n"
|
||||||
|
"Verify the vLLM installation path.")
|
||||||
|
|
||||||
|
print(f"=== Patching {INIT_FILE} ===")
|
||||||
|
patch_file(INIT_FILE, [
|
||||||
|
(
|
||||||
|
"from .mistral_tool_parser import MistralToolParser",
|
||||||
|
"from .mistral_tool_parser import MistralToolParser\n"
|
||||||
|
"from .qwen3coder_tool_parser import Qwen3CoderToolParser",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"\n]',
|
||||||
|
'"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",\n'
|
||||||
|
' "Qwen3CoderToolParser"\n]',
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
print("\n=== Verification ===")
|
||||||
|
try:
|
||||||
|
import importlib.util
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"qwen3coder_tool_parser",
|
||||||
|
f"{TOOL_PARSERS_DIR}/qwen3coder_tool_parser.py",
|
||||||
|
)
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
print(f" Module spec loaded: {spec.name}")
|
||||||
|
print(" (full import requires torch/vllm runtime — skipping exec)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [warn] spec check failed: {e}")
|
||||||
|
|
||||||
|
print("\nDone. Start vLLM server with:")
|
||||||
|
print(" --tool-call-parser qwen3_coder --enable-auto-tool-choice")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1020
qwen3_6_scripts/protocol.py
Normal file
1020
qwen3_6_scripts/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -11,12 +11,15 @@ from torch import nn
|
|||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
@@ -417,9 +420,6 @@ class GatedDeltaNet(nn.Module):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# Decode: one token per sequence
|
# Decode: one token per sequence
|
||||||
with open("/tmp/vllm_decode_debug.log", "a") as _f:
|
|
||||||
_f.write(f"[deltanet decode] layer={self.layer_idx} num_seqs={hidden_states.shape[0]}\n")
|
|
||||||
_f.flush()
|
|
||||||
num_seqs = hidden_states.shape[0]
|
num_seqs = hidden_states.shape[0]
|
||||||
weight_2d = self.conv1d_weight.squeeze(1)
|
weight_2d = self.conv1d_weight.squeeze(1)
|
||||||
|
|
||||||
@@ -449,17 +449,47 @@ class GatedDeltaNet(nn.Module):
|
|||||||
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
q = q.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
k = k.repeat_interleave(self.head_expand_ratio, dim=2)
|
||||||
|
|
||||||
core_out, last_state = _torch_recurrent_gated_delta_rule(
|
# Inlined decode recurrent step (seq_len=1).
|
||||||
q, k, v, g, beta,
|
# Replaces _torch_recurrent_gated_delta_rule to avoid 5 transpose+
|
||||||
initial_state=temporal_state,
|
# contiguous+float32 copies, core_out allocation, and Python loop.
|
||||||
output_final_state=True,
|
# Uses bmm/baddbmm_ to eliminate 3 large (B,H,k,v) intermediate tensors.
|
||||||
use_qk_l2norm_in_kernel=True,
|
# temporal_state: (B, H_v, k_dim, v_dim) float32 — updated in-place.
|
||||||
|
orig_dtype = q.dtype
|
||||||
|
_scale = self.head_k_dim ** -0.5
|
||||||
|
|
||||||
|
q_t = _l2norm(q.squeeze(1)).float() * _scale # (B, H_v, k_dim)
|
||||||
|
k_t = _l2norm(k.squeeze(1)).float() # (B, H_v, k_dim)
|
||||||
|
v_t = v.squeeze(1).float() # (B, H_v, v_dim)
|
||||||
|
g_t = g.squeeze(1).float().exp_() # (B, H_v)
|
||||||
|
bt = beta.squeeze(1).float() # (B, H_v)
|
||||||
|
|
||||||
|
# Decay state in-place: (B, H_v, k_dim, v_dim) *= scalar per head
|
||||||
|
temporal_state.mul_(g_t[:, :, None, None])
|
||||||
|
|
||||||
|
# Reshape to batched-matmul layout: (B*H_v, k_dim, v_dim)
|
||||||
|
ts_flat = temporal_state.view(-1, self.head_k_dim, self.head_v_dim)
|
||||||
|
BH = ts_flat.shape[0]
|
||||||
|
|
||||||
|
# kv_mem = k_t @ temporal_state shape: (B*H_v, 1, k_dim) @ (B*H_v, k_dim, v_dim)
|
||||||
|
kv_mem = torch.bmm(
|
||||||
|
k_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||||
|
).view(num_seqs, local_num_v, self.head_v_dim) # (B, H_v, v_dim)
|
||||||
|
|
||||||
|
delta = (v_t - kv_mem) * bt[:, :, None] # (B, H_v, v_dim)
|
||||||
|
|
||||||
|
# State update: temporal_state += outer(k_t, delta) fused, no intermediate
|
||||||
|
ts_flat.baddbmm_(
|
||||||
|
k_t.view(BH, self.head_k_dim, 1),
|
||||||
|
delta.view(BH, 1, self.head_v_dim),
|
||||||
)
|
)
|
||||||
if last_state is not None:
|
|
||||||
temporal_state.copy_(last_state)
|
# Output: core_out = q_t @ updated temporal_state
|
||||||
|
core_out = torch.bmm(
|
||||||
|
q_t.view(BH, 1, self.head_k_dim), ts_flat
|
||||||
|
).view(num_seqs, local_num_v, self.head_v_dim).to(orig_dtype)
|
||||||
|
# core_out: (B, H_v, v_dim) = (num_seqs, local_num_v, head_v_dim) already
|
||||||
|
|
||||||
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
z = z_all.reshape(num_seqs, local_num_v, self.head_v_dim)
|
||||||
core_out = core_out.reshape(num_seqs, local_num_v, self.head_v_dim)
|
|
||||||
normed = self.norm(
|
normed = self.norm(
|
||||||
core_out.reshape(-1, self.head_v_dim),
|
core_out.reshape(-1, self.head_v_dim),
|
||||||
z.reshape(-1, self.head_v_dim))
|
z.reshape(-1, self.head_v_dim))
|
||||||
@@ -495,24 +525,52 @@ class Qwen3_5FullAttention(nn.Module):
|
|||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.local_num_heads = self.num_heads // tp_size
|
self.local_num_heads = self.num_heads // tp_size
|
||||||
self.local_num_kv_heads = max(1, self.num_kv_heads // tp_size)
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
# When num_kv_heads < tp_size we cannot shard KV further (would give
|
||||||
|
# fractional heads per rank). Use ReplicatedLinear so every rank holds
|
||||||
|
# all KV heads; local_num_kv_heads equals the full count.
|
||||||
|
# When num_kv_heads >= tp_size standard ColumnParallel sharding applies.
|
||||||
|
if tp_size > self.num_kv_heads:
|
||||||
|
# GQA-aware TP sharding: ixformer kernel only supports num_kv_heads=1
|
||||||
|
# per rank. With num_kv_heads=2 < tp_size=4 we cannot shard KV
|
||||||
|
# evenly, but we CAN assign each rank the ONE KV head that serves
|
||||||
|
# its Q heads:
|
||||||
|
# q_per_kv = num_heads // num_kv_heads (e.g. 16//2 = 8)
|
||||||
|
# Rank r uses KV head r * local_num_heads // q_per_kv
|
||||||
|
# e.g. ranks 0,1 → KV head 0; ranks 2,3 → KV head 1.
|
||||||
|
# We replicate all KV heads to every rank and select in forward().
|
||||||
|
self.proj_kv_heads = self.num_kv_heads # heads available from projection
|
||||||
|
self.local_num_kv_heads = 1 # heads after rank-local selection
|
||||||
|
self.q_per_kv_global = self.num_heads // self.num_kv_heads
|
||||||
|
self.k_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||||
|
bias=False, quant_config=quant_config)
|
||||||
|
self.v_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||||
|
bias=False, quant_config=quant_config)
|
||||||
|
else:
|
||||||
|
# Standard sharding: each rank gets num_kv_heads // tp_size heads.
|
||||||
|
self.local_num_kv_heads = self.num_kv_heads // tp_size
|
||||||
|
self.proj_kv_heads = self.local_num_kv_heads # already sharded
|
||||||
|
self.q_per_kv_global = None
|
||||||
|
self.k_proj = ColumnParallelLinear(
|
||||||
|
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||||
|
bias=False, quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.k_proj")
|
||||||
|
self.v_proj = ColumnParallelLinear(
|
||||||
|
self.hidden_size, self.num_kv_heads * self.head_dim,
|
||||||
|
bias=False, quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.v_proj")
|
||||||
|
|
||||||
self.local_q_dim = self.local_num_heads * self.head_dim
|
self.local_q_dim = self.local_num_heads * self.head_dim
|
||||||
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
|
self.local_kv_dim = self.local_num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim ** -0.5
|
|
||||||
|
|
||||||
# q_proj includes gate: output = num_heads * head_dim * 2
|
# q_proj includes gate: output = num_heads * head_dim * 2
|
||||||
self.q_proj = ColumnParallelLinear(
|
self.q_proj = ColumnParallelLinear(
|
||||||
self.hidden_size, self.num_heads * self.head_dim * 2,
|
self.hidden_size, self.num_heads * self.head_dim * 2,
|
||||||
bias=False, quant_config=quant_config,
|
bias=False, quant_config=quant_config,
|
||||||
prefix=f"{prefix}.q_proj")
|
prefix=f"{prefix}.q_proj")
|
||||||
self.k_proj = ColumnParallelLinear(
|
|
||||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
|
||||||
bias=False, quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.k_proj")
|
|
||||||
self.v_proj = ColumnParallelLinear(
|
|
||||||
self.hidden_size, self.num_kv_heads * self.head_dim,
|
|
||||||
bias=False, quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.v_proj")
|
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.num_heads * self.head_dim, self.hidden_size,
|
self.num_heads * self.head_dim, self.hidden_size,
|
||||||
bias=False, quant_config=quant_config,
|
bias=False, quant_config=quant_config,
|
||||||
@@ -559,18 +617,34 @@ class Qwen3_5FullAttention(nn.Module):
|
|||||||
q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
|
q = qg[:, :, :self.head_dim].reshape(total_tokens, -1)
|
||||||
gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
|
gate = qg[:, :, self.head_dim:].reshape(total_tokens, -1)
|
||||||
|
|
||||||
k, _ = self.k_proj(hidden_states) # (total, local_kv_dim)
|
k, _ = self.k_proj(hidden_states) # (total, proj_kv_heads * head_dim)
|
||||||
v, _ = self.v_proj(hidden_states)
|
v, _ = self.v_proj(hidden_states)
|
||||||
|
|
||||||
# Per-head RMSNorm
|
# q_norm on local Q heads
|
||||||
q = self.q_norm.forward_cuda(
|
q = self.q_norm.forward_cuda(
|
||||||
q.view(total_tokens, self.local_num_heads, self.head_dim)
|
q.view(total_tokens, self.local_num_heads, self.head_dim)
|
||||||
.contiguous()).view(total_tokens, -1)
|
.contiguous()).view(total_tokens, -1)
|
||||||
|
|
||||||
|
# GQA-aware TP: select rank-local KV head BEFORE k_norm and rope so
|
||||||
|
# that ixformer kernels always see num_kv_heads=1 (same as 27B path).
|
||||||
|
# Doing k_norm/rope on 2 KV heads (proj_kv_heads=2) triggers ixformer
|
||||||
|
# paths that can produce NaN; restricting to 1 head avoids the issue.
|
||||||
|
if self.q_per_kv_global is not None:
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
kv_idx = (tp_rank * self.local_num_heads) // self.q_per_kv_global
|
||||||
|
k = (k.view(total_tokens, self.proj_kv_heads, self.head_dim)
|
||||||
|
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
|
||||||
|
v = (v.view(total_tokens, self.proj_kv_heads, self.head_dim)
|
||||||
|
[:, kv_idx, :].contiguous()) # (T, head_dim) — 1 head
|
||||||
|
|
||||||
|
# k_norm on the (now always 1) rank-local KV head
|
||||||
k = self.k_norm.forward_cuda(
|
k = self.k_norm.forward_cuda(
|
||||||
k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
|
k.view(total_tokens, self.local_num_kv_heads, self.head_dim)
|
||||||
.contiguous()).view(total_tokens, -1)
|
.contiguous()).view(total_tokens, -1)
|
||||||
|
|
||||||
|
# rope: q=(T, local_num_heads*head_dim), k=(T, 1*head_dim) — mirrors 27B
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
|
||||||
attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
|
attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
|
||||||
# Multiply by sigmoid gate before output projection
|
# Multiply by sigmoid gate before output projection
|
||||||
@@ -609,10 +683,154 @@ class Qwen3_5MLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MoE sparse block (Qwen3.5-MoE / Qwen3.6-35B-A3B)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class Qwen3_5MoeSparseBlock(nn.Module):
|
||||||
|
"""Replaces Qwen3_5MLP for qwen3_5_moe_text layers.
|
||||||
|
|
||||||
|
FusedMoE is used ONLY for weight storage and loading (create_weights /
|
||||||
|
weight_loader are pure PyTorch). Its forward kernel is bypassed because
|
||||||
|
ixformer on BI-V100 lacks vllm_moe_topk_softmax / vllm_invoke_fused_moe_kernel.
|
||||||
|
Routing and expert computation use a pure-PyTorch loop instead.
|
||||||
|
|
||||||
|
Shared expert uses RowParallelLinear(reduce_results=False) so both paths
|
||||||
|
produce partial (pre-all-reduce) outputs that are combined before a single
|
||||||
|
all-reduce.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_cfg,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = text_cfg.hidden_size
|
||||||
|
self.num_experts = text_cfg.num_experts
|
||||||
|
self.top_k = text_cfg.num_experts_per_tok
|
||||||
|
|
||||||
|
# Router: replicated (small: num_experts outputs)
|
||||||
|
self.gate = ReplicatedLinear(hidden_size, text_cfg.num_experts,
|
||||||
|
bias=False, quant_config=quant_config)
|
||||||
|
|
||||||
|
# FusedMoE: only used for weight storage + weight_loader.
|
||||||
|
# Forward is bypassed — see _pure_pytorch_experts().
|
||||||
|
self.experts = FusedMoE(
|
||||||
|
num_experts=text_cfg.num_experts,
|
||||||
|
top_k=text_cfg.num_experts_per_tok,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
intermediate_size=text_cfg.moe_intermediate_size,
|
||||||
|
reduce_results=False, # we do the all-reduce ourselves below
|
||||||
|
renormalize=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shared expert: defer all-reduce to combine with routed output first
|
||||||
|
shared_size = text_cfg.shared_expert_intermediate_size
|
||||||
|
self.shared_expert_gate_up = MergedColumnParallelLinear(
|
||||||
|
hidden_size, [shared_size] * 2, bias=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.shared_expert_down = RowParallelLinear(
|
||||||
|
shared_size, hidden_size, bias=False, reduce_results=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
# Scalar sigmoid gate on shared expert output (same as Qwen2-MoE / Qwen3.5-MoE):
|
||||||
|
# shared_out *= sigmoid(shared_expert_gate(hidden_states))
|
||||||
|
# Without this, shared expert is always fully active → wrong logits.
|
||||||
|
self.shared_expert_gate = ReplicatedLinear(
|
||||||
|
hidden_size, 1, bias=False, quant_config=quant_config)
|
||||||
|
|
||||||
|
def _pure_pytorch_experts(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Pure-PyTorch MoE (ixformer has no MoE kernels on BI-V100).
|
||||||
|
|
||||||
|
w13_weight: (num_experts, 2*inter_per_partition, hidden) [TP-sharded]
|
||||||
|
w2_weight: (num_experts, hidden, inter_per_partition) [TP-sharded]
|
||||||
|
Output is partial (pre-all-reduce), same contract as FusedMoE
|
||||||
|
with reduce_results=False.
|
||||||
|
"""
|
||||||
|
# Routing: softmax → topk → renormalise
|
||||||
|
routing_weights = torch.softmax(router_logits.float(), dim=-1)
|
||||||
|
topk_weights, topk_ids = torch.topk(
|
||||||
|
routing_weights, self.top_k, dim=-1) # (T, top_k)
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
w13 = self.experts.w13_weight # (E, 2*I, H)
|
||||||
|
w2 = self.experts.w2_weight # (E, H, I)
|
||||||
|
|
||||||
|
T = hidden_states.shape[0]
|
||||||
|
if T == 1:
|
||||||
|
# Fast path: single token (decode).
|
||||||
|
# Batched GEMM: replace top_k separate F.linear calls with 2 fused ops.
|
||||||
|
# gate_up: 1 large GEMM (1,H) × (K*2*I,H)^T → (1, K*2*I)
|
||||||
|
# down: 1 bmm (K,H,I) @ (K,I,1) → (K,H)
|
||||||
|
# Total: 3 kernel launches vs previous 16 (top_k*2).
|
||||||
|
eids = topk_ids[0] # (K,)
|
||||||
|
ws = topk_weights[0].to(hidden_states.dtype) # (K,)
|
||||||
|
w13_sel = w13[eids] # (K, 2*I, H)
|
||||||
|
w2_sel = w2[eids] # (K, H, I)
|
||||||
|
|
||||||
|
H = hidden_states.shape[-1]
|
||||||
|
|
||||||
|
gate_up = F.linear(
|
||||||
|
hidden_states,
|
||||||
|
w13_sel.reshape(-1, H), # (K*2*I, H) — contiguous after indexing
|
||||||
|
) # (1, K*2*I)
|
||||||
|
gate_up = gate_up.view(self.top_k, -1) # (K, 2*I)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1) # (K, I) each
|
||||||
|
act = F.silu(gate) * up # (K, I)
|
||||||
|
|
||||||
|
# bmm: (K,H,I) @ (K,I,1) → (K,H,1) → (K,H)
|
||||||
|
expert_out = torch.bmm(w2_sel, act.unsqueeze(-1)).squeeze(-1) # (K, H)
|
||||||
|
|
||||||
|
out = (expert_out * ws.unsqueeze(-1)).sum(0, keepdim=True).to(
|
||||||
|
hidden_states.dtype) # (1, H)
|
||||||
|
else:
|
||||||
|
# General path (prefill / multi-seq): loop over unique active experts.
|
||||||
|
# At most T*top_k unique experts, always <= num_experts.
|
||||||
|
out = torch.zeros_like(hidden_states)
|
||||||
|
unique_eids = topk_ids.view(-1).unique().tolist()
|
||||||
|
for eid in unique_eids:
|
||||||
|
eid = int(eid)
|
||||||
|
mask = (topk_ids == eid) # (T, top_k)
|
||||||
|
tok_ids, topk_pos = mask.nonzero(as_tuple=True)
|
||||||
|
tokens = hidden_states[tok_ids] # (n, H)
|
||||||
|
gate_up = F.linear(tokens, w13[eid]) # (n, 2*I)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1)
|
||||||
|
act = F.silu(gate) * up # (n, I)
|
||||||
|
expert_out = F.linear(act, w2[eid]) # (n, H)
|
||||||
|
weights = topk_weights[tok_ids, topk_pos].unsqueeze(-1)
|
||||||
|
out.index_add_(0, tok_ids, (expert_out * weights).to(out.dtype))
|
||||||
|
|
||||||
|
return out # partial, all-reduce done in forward()
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
routed_out = self._pure_pytorch_experts(hidden_states, router_logits)
|
||||||
|
|
||||||
|
gate_up, _ = self.shared_expert_gate_up(hidden_states)
|
||||||
|
shared_out = self.act_fn(gate_up)
|
||||||
|
shared_out, _ = self.shared_expert_down(shared_out)
|
||||||
|
# Scalar sigmoid gate (Qwen2-MoE / Qwen3.5-MoE style)
|
||||||
|
gate_score, _ = self.shared_expert_gate(hidden_states) # (T, 1)
|
||||||
|
shared_out = shared_out * torch.sigmoid(gate_score)
|
||||||
|
|
||||||
|
out = routed_out + shared_out
|
||||||
|
if self.experts.tp_size > 1:
|
||||||
|
out = tensor_model_parallel_all_reduce(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
|
# Decoder layer (dispatches to GatedDeltaNet or Qwen3_5FullAttention)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class Qwen3_5DecoderLayer(nn.Module):
|
class Qwen3_5DecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -623,6 +841,7 @@ class Qwen3_5DecoderLayer(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.layer_idx = layer_idx
|
||||||
self.layer_type = layer_type
|
self.layer_type = layer_type
|
||||||
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
self.input_layernorm = GemmaRMSNorm(text_cfg.hidden_size,
|
||||||
eps=text_cfg.rms_norm_eps)
|
eps=text_cfg.rms_norm_eps)
|
||||||
@@ -640,12 +859,15 @@ class Qwen3_5DecoderLayer(nn.Module):
|
|||||||
prefix=f"layers.{layer_idx}.self_attn",
|
prefix=f"layers.{layer_idx}.self_attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = Qwen3_5MLP(
|
if getattr(text_cfg, 'model_type', '') == 'qwen3_5_moe_text':
|
||||||
hidden_size=text_cfg.hidden_size,
|
self.mlp = Qwen3_5MoeSparseBlock(text_cfg, quant_config=quant_config)
|
||||||
intermediate_size=text_cfg.intermediate_size,
|
else:
|
||||||
hidden_act=text_cfg.hidden_act,
|
self.mlp = Qwen3_5MLP(
|
||||||
quant_config=quant_config,
|
hidden_size=text_cfg.hidden_size,
|
||||||
)
|
intermediate_size=text_cfg.intermediate_size,
|
||||||
|
hidden_act=text_cfg.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -673,7 +895,9 @@ class Qwen3_5DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -860,8 +1084,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
|||||||
# With chunked prefill, intermediate chunks have seq_groups=None on all
|
# With chunked prefill, intermediate chunks have seq_groups=None on all
|
||||||
# ranks; _apply_logits_processors is guarded against this in
|
# ranks; _apply_logits_processors is guarded against this in
|
||||||
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
|
# logits_processor.py (patched by patch_xformers_sdpa_seq.py).
|
||||||
return self.logits_processor(self.lm_head, hidden_states,
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
sampling_metadata)
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
@@ -892,12 +1117,9 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
|||||||
or name.startswith("model.mtp")):
|
or name.startswith("model.mtp")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remap checkpoint prefix → module path
|
# Prefix remapping: checkpoint may wrap under language_model
|
||||||
# Checkpoint: "model.language_model.{rest}" → our module: "model.{rest}"
|
|
||||||
# Checkpoint: "lm_head.weight" → our module: "lm_head.weight"
|
|
||||||
if name.startswith("model.language_model."):
|
if name.startswith("model.language_model."):
|
||||||
name = "model." + name[len("model.language_model."):]
|
name = "model." + name[len("model.language_model."):]
|
||||||
# lm_head is already at top level — no change needed
|
|
||||||
|
|
||||||
# Skip positional embedding caches
|
# Skip positional embedding caches
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
@@ -931,3 +1153,118 @@ class Qwen3_5ForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
|||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Qwen3.6-35B-A3B (Qwen3_5-MoE architecture)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLM):
|
||||||
|
"""Qwen3.6-35B-A3B: same hybrid-attention backbone as 27B, dense MLP
|
||||||
|
replaced by Qwen3_5MoeSparseBlock (256 routed experts + shared expert).
|
||||||
|
Only load_weights differs from the dense variant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
# Checkpoint key format for this model (transformers Qwen3_5MoeExperts):
|
||||||
|
# mlp.experts.gate_up_proj shape (num_experts, 2*intermediate, hidden)
|
||||||
|
# mlp.experts.down_proj shape (num_experts, hidden, intermediate)
|
||||||
|
# mlp.gate.weight shape (num_experts, hidden) [router]
|
||||||
|
# mlp.shared_expert.{gate,up,down}_proj.weight [shared MLP]
|
||||||
|
# Our FusedMoE stores:
|
||||||
|
# mlp.experts.w13_weight shape (num_experts, 2*intermediate//tp, hidden)
|
||||||
|
# mlp.experts.w2_weight shape (num_experts, hidden, intermediate//tp)
|
||||||
|
# Our shared expert stores:
|
||||||
|
# mlp.shared_expert_gate_up.weight (merged gate+up)
|
||||||
|
# mlp.shared_expert_down.weight
|
||||||
|
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, weight_name, shard_id)
|
||||||
|
# shared expert
|
||||||
|
("shared_expert_gate_up", "shared_expert.gate_proj", 0),
|
||||||
|
("shared_expert_gate_up", "shared_expert.up_proj", 1),
|
||||||
|
# linear_attention dense proj (same as 27B)
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
# Skip vision and MTP branches
|
||||||
|
if (name.startswith("model.visual")
|
||||||
|
or name.startswith("mtp.")
|
||||||
|
or name.startswith("model.mtp")):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prefix remapping for VL checkpoint (Qwen3_5MoeForConditionalGeneration):
|
||||||
|
# model.language_model.model.{layers,embed_tokens,norm} -> model.{...}
|
||||||
|
# model.language_model.lm_head -> lm_head
|
||||||
|
# Prefix remapping: checkpoint may wrap under language_model
|
||||||
|
if name.startswith("model.language_model."):
|
||||||
|
name = "model." + name[len("model.language_model."):]
|
||||||
|
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ".linear_attn.conv1d.weight" in name:
|
||||||
|
name = name.replace(".linear_attn.conv1d.weight",
|
||||||
|
".linear_attn.conv1d_weight")
|
||||||
|
|
||||||
|
# --- Fused routed-expert weights (all experts in one tensor) ---
|
||||||
|
|
||||||
|
if "mlp.experts.gate_up_proj" in name:
|
||||||
|
# loaded_weight: (num_experts, 2*intermediate, hidden)
|
||||||
|
w13_name = name.replace("mlp.experts.gate_up_proj",
|
||||||
|
"mlp.experts.w13_weight")
|
||||||
|
if w13_name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[w13_name]
|
||||||
|
n_exp = loaded_weight.shape[0]
|
||||||
|
inter = loaded_weight.shape[1] // 2
|
||||||
|
gate_w = loaded_weight[:, :inter, :].contiguous()
|
||||||
|
up_w = loaded_weight[:, inter:, :].contiguous()
|
||||||
|
for eid in range(n_exp):
|
||||||
|
param.weight_loader(param, gate_w[eid], "w1_weight", "w1", eid)
|
||||||
|
param.weight_loader(param, up_w[eid], "w3_weight", "w3", eid)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "mlp.experts.down_proj" in name:
|
||||||
|
# loaded_weight: (num_experts, hidden, intermediate)
|
||||||
|
w2_name = name.replace("mlp.experts.down_proj",
|
||||||
|
"mlp.experts.w2_weight")
|
||||||
|
if w2_name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[w2_name]
|
||||||
|
n_exp = loaded_weight.shape[0]
|
||||||
|
for eid in range(n_exp):
|
||||||
|
param.weight_loader(param, loaded_weight[eid], "w2_weight", "w2", eid)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- Shared expert down_proj rename ---
|
||||||
|
if "mlp.shared_expert.down_proj" in name:
|
||||||
|
name = name.replace("mlp.shared_expert.down_proj",
|
||||||
|
"mlp.shared_expert_down")
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- Stacked / standard weights ---
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
if name not in params_dict:
|
||||||
|
break
|
||||||
|
param = params_dict[name]
|
||||||
|
param.weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
3
qwen3_6_scripts/qwen3_5_moe/__init__.py
Normal file
3
qwen3_6_scripts/qwen3_5_moe/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig
|
||||||
|
|
||||||
|
__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
|
||||||
198
qwen3_6_scripts/qwen3_5_moe/configuration_qwen3_5_moe.py
Normal file
198
qwen3_6_scripts/qwen3_5_moe/configuration_qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Adapted from transformers 5.2.0 for compatibility with transformers 4.55.3 + torch 2.1.0
|
||||||
|
# Source: transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
|
||||||
|
# Stubs layer_type_validation and RopeParameters which do not exist in 4.55.3
|
||||||
|
# Removes ignore_keys_at_rope_validation / base_model_tp_plan / base_model_pp_plan
|
||||||
|
# which are 5.x-only and irrelevant for vLLM inference.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig as PreTrainedConfig
|
||||||
|
|
||||||
|
# --- Local stubs for APIs not present in transformers 4.55.3 ---
|
||||||
|
def layer_type_validation(layer_types, num_hidden_layers=None, attention=True):
|
||||||
|
allowed = {"full_attention", "linear_attention"}
|
||||||
|
if not all(lt in allowed for lt in layer_types):
|
||||||
|
raise ValueError(f"layer_types entries must be in {allowed}, got {layer_types}")
|
||||||
|
if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"num_hidden_layers ({num_hidden_layers}) != len(layer_types) ({len(layer_types)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import TypedDict
|
||||||
|
class RopeParameters(TypedDict, total=False):
|
||||||
|
rope_theta: float
|
||||||
|
rope_type: str
|
||||||
|
partial_rotary_factor: float
|
||||||
|
factor: float
|
||||||
|
except Exception:
|
||||||
|
RopeParameters = dict
|
||||||
|
|
||||||
|
# --- End stubs ---
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3_5MoeTextConfig(PreTrainedConfig):
|
||||||
|
r"""
|
||||||
|
Configuration for the text backbone of Qwen3.5-MoE / Qwen3.6-35B-A3B models.
|
||||||
|
model_type is "qwen3_5_moe_text" (used internally by the nested config).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "qwen3_5_moe_text"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=248320,
|
||||||
|
hidden_size=2048,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=32768,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_parameters=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
head_dim=256,
|
||||||
|
linear_conv_kernel_dim=4,
|
||||||
|
linear_key_head_dim=128,
|
||||||
|
linear_value_head_dim=128,
|
||||||
|
linear_num_key_heads=16,
|
||||||
|
linear_num_value_heads=32,
|
||||||
|
moe_intermediate_size=512,
|
||||||
|
shared_expert_intermediate_size=512,
|
||||||
|
num_experts_per_tok=8,
|
||||||
|
num_experts=256,
|
||||||
|
output_router_logits=False,
|
||||||
|
router_aux_loss_coef=0.001,
|
||||||
|
layer_types=None,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.rope_parameters = rope_parameters
|
||||||
|
kwargs.setdefault("partial_rotary_factor", 0.25)
|
||||||
|
|
||||||
|
self.layer_types = layer_types
|
||||||
|
if self.layer_types is None:
|
||||||
|
interval_pattern = kwargs.get("full_attention_interval", 4)
|
||||||
|
self.layer_types = [
|
||||||
|
"linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
|
||||||
|
for i in range(self.num_hidden_layers)
|
||||||
|
]
|
||||||
|
layer_type_validation(self.layer_types, self.num_hidden_layers)
|
||||||
|
|
||||||
|
self.linear_conv_kernel_dim = linear_conv_kernel_dim
|
||||||
|
self.linear_key_head_dim = linear_key_head_dim
|
||||||
|
self.linear_value_head_dim = linear_value_head_dim
|
||||||
|
self.linear_num_key_heads = linear_num_key_heads
|
||||||
|
self.linear_num_value_heads = linear_num_value_heads
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.shared_expert_intermediate_size = shared_expert_intermediate_size
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.output_router_logits = output_router_logits
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3_5MoeVisionConfig(PreTrainedConfig):
|
||||||
|
model_type = "qwen3_5_moe"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depth=27,
|
||||||
|
hidden_size=1152,
|
||||||
|
hidden_act="gelu_pytorch_tanh",
|
||||||
|
intermediate_size=4304,
|
||||||
|
num_heads=16,
|
||||||
|
in_channels=3,
|
||||||
|
patch_size=16,
|
||||||
|
spatial_merge_size=2,
|
||||||
|
temporal_patch_size=2,
|
||||||
|
out_hidden_size=3584,
|
||||||
|
num_position_embeddings=2304,
|
||||||
|
initializer_range=0.02,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.depth = depth
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
self.temporal_patch_size = temporal_patch_size
|
||||||
|
self.out_hidden_size = out_hidden_size
|
||||||
|
self.num_position_embeddings = num_position_embeddings
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3_5MoeConfig(PreTrainedConfig):
|
||||||
|
r"""
|
||||||
|
Top-level configuration for Qwen3.5-MoE / Qwen3.6-35B-A3B.
|
||||||
|
model_type = "qwen3_5_moe" matches the model card / config.json.
|
||||||
|
Wraps Qwen3_5MoeTextConfig (and optionally Qwen3_5MoeVisionConfig).
|
||||||
|
For vLLM text-only inference only text_config is consumed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "qwen3_5_moe"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_config=None,
|
||||||
|
vision_config=None,
|
||||||
|
image_token_id=248056,
|
||||||
|
video_token_id=248057,
|
||||||
|
vision_start_token_id=248053,
|
||||||
|
vision_end_token_id=248054,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if isinstance(text_config, dict):
|
||||||
|
self.text_config = Qwen3_5MoeTextConfig(**text_config)
|
||||||
|
elif text_config is None:
|
||||||
|
self.text_config = Qwen3_5MoeTextConfig()
|
||||||
|
else:
|
||||||
|
self.text_config = text_config
|
||||||
|
|
||||||
|
if isinstance(vision_config, dict):
|
||||||
|
self.vision_config = Qwen3_5MoeVisionConfig(**vision_config)
|
||||||
|
elif vision_config is None:
|
||||||
|
self.vision_config = Qwen3_5MoeVisionConfig()
|
||||||
|
else:
|
||||||
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
self.image_token_id = image_token_id
|
||||||
|
self.video_token_id = video_token_id
|
||||||
|
self.vision_start_token_id = vision_start_token_id
|
||||||
|
self.vision_end_token_id = vision_end_token_id
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig"]
|
||||||
509
qwen3_6_scripts/qwen3coder_tool_parser.py
Normal file
509
qwen3_6_scripts/qwen3coder_tool_parser.py
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionToolsParam,
|
||||||
|
DeltaFunctionCall, DeltaMessage,
|
||||||
|
DeltaToolCall,
|
||||||
|
ExtractedToolCallInformation,
|
||||||
|
FunctionCall, ToolCall)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserManager.register_module("qwen3_coder")
|
||||||
|
class Qwen3CoderToolParser(ToolParser):
|
||||||
|
"""
|
||||||
|
Tool parser for Qwen3 models using XML-style tool call format:
|
||||||
|
<tool_call><function=name><parameter=key>
|
||||||
|
value
|
||||||
|
</parameter></function></tool_call>
|
||||||
|
|
||||||
|
Port of vllm-original qwen3coder_tool_parser.py to vllm 0.6.3 API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: AnyTokenizer):
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
|
||||||
|
self.current_tool_name_sent: bool = False
|
||||||
|
self.prev_tool_call_arr: List[Dict] = []
|
||||||
|
# Base class uses int; we override with string IDs
|
||||||
|
self.current_tool_id: Optional[str] = None # type: ignore[assignment]
|
||||||
|
self.streamed_args_for_tool: List[str] = []
|
||||||
|
|
||||||
|
self.tool_call_start_token: str = "<tool_call>"
|
||||||
|
self.tool_call_end_token: str = "</tool_call>"
|
||||||
|
self.tool_call_prefix: str = "<function="
|
||||||
|
self.function_end_token: str = "</function>"
|
||||||
|
self.parameter_prefix: str = "<parameter="
|
||||||
|
self.parameter_end_token: str = "</parameter>"
|
||||||
|
self.is_tool_call_started: bool = False
|
||||||
|
|
||||||
|
self._reset_streaming_state()
|
||||||
|
|
||||||
|
self.tool_call_complete_regex = re.compile(
|
||||||
|
r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||||
|
self.tool_call_regex = re.compile(
|
||||||
|
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
|
||||||
|
self.tool_call_function_regex = re.compile(
|
||||||
|
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
|
||||||
|
self.tool_call_parameter_regex = re.compile(
|
||||||
|
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||||
|
re.DOTALL)
|
||||||
|
|
||||||
|
if not self.model_tokenizer:
|
||||||
|
raise ValueError(
|
||||||
|
"The model tokenizer must be passed to the ToolParser "
|
||||||
|
"constructor during construction.")
|
||||||
|
|
||||||
|
self.tool_call_start_token_id = self.vocab.get(
|
||||||
|
self.tool_call_start_token)
|
||||||
|
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||||
|
|
||||||
|
if (self.tool_call_start_token_id is None
|
||||||
|
or self.tool_call_end_token_id is None):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Qwen3 XML Tool parser could not locate tool call start/end "
|
||||||
|
"tokens in the tokenizer!")
|
||||||
|
|
||||||
|
logger.debug("vLLM Successfully imported tool parser %s !",
|
||||||
|
self.__class__.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_tool_call_id(self) -> str:
|
||||||
|
return f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def _reset_streaming_state(self) -> None:
|
||||||
|
self.current_tool_index = 0
|
||||||
|
self.is_tool_call_started = False
|
||||||
|
self.header_sent = False
|
||||||
|
self.current_tool_id = None
|
||||||
|
self.current_function_name: Optional[str] = None
|
||||||
|
self.current_param_name: Optional[str] = None
|
||||||
|
self.current_param_value: str = ""
|
||||||
|
self.param_count = 0
|
||||||
|
self.in_param = False
|
||||||
|
self.in_function = False
|
||||||
|
self.accumulated_text: str = ""
|
||||||
|
self.json_started = False
|
||||||
|
self.json_closed = False
|
||||||
|
self.accumulated_params: Dict[str, Any] = {}
|
||||||
|
self.streaming_request: Optional[ChatCompletionRequest] = None
|
||||||
|
|
||||||
|
def _get_arguments_config(
|
||||||
|
self, func_name: str,
|
||||||
|
tools: Optional[List[ChatCompletionToolsParam]]) -> Dict:
|
||||||
|
if tools is None:
|
||||||
|
return {}
|
||||||
|
for config in tools:
|
||||||
|
if not hasattr(config, "type") or not (
|
||||||
|
hasattr(config, "function")
|
||||||
|
and hasattr(config.function, "name")):
|
||||||
|
continue
|
||||||
|
if config.type == "function" and config.function.name == func_name:
|
||||||
|
if not hasattr(config.function, "parameters"):
|
||||||
|
return {}
|
||||||
|
params = config.function.parameters
|
||||||
|
if isinstance(params, dict) and "properties" in params:
|
||||||
|
return params["properties"]
|
||||||
|
elif isinstance(params, dict):
|
||||||
|
return params
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
logger.debug("Tool '%s' is not defined in the tools list.", func_name)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _convert_param_value(self, param_value: str, param_name: str,
|
||||||
|
param_config: Dict, func_name: str) -> Any:
|
||||||
|
if param_value.lower() == "null":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if param_name not in param_config:
|
||||||
|
if param_config != {}:
|
||||||
|
logger.debug(
|
||||||
|
"Parsed parameter '%s' is not defined in tool '%s', "
|
||||||
|
"returning string value.", param_name, func_name)
|
||||||
|
return param_value
|
||||||
|
|
||||||
|
if (isinstance(param_config[param_name], dict)
|
||||||
|
and "type" in param_config[param_name]):
|
||||||
|
param_type = str(
|
||||||
|
param_config[param_name]["type"]).strip().lower()
|
||||||
|
else:
|
||||||
|
param_type = "string"
|
||||||
|
|
||||||
|
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
|
||||||
|
return param_value
|
||||||
|
elif (param_type.startswith("int") or param_type.startswith("uint")
|
||||||
|
or param_type.startswith("long")
|
||||||
|
or param_type.startswith("short")
|
||||||
|
or param_type.startswith("unsigned")):
|
||||||
|
try:
|
||||||
|
return int(param_value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return param_value
|
||||||
|
elif param_type.startswith("num") or param_type.startswith("float"):
|
||||||
|
try:
|
||||||
|
v = float(param_value)
|
||||||
|
return int(v) if v - int(v) == 0 else v
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return param_value
|
||||||
|
elif param_type in ["boolean", "bool", "binary"]:
|
||||||
|
lower = param_value.lower()
|
||||||
|
if lower not in ["true", "false"]:
|
||||||
|
logger.debug(
|
||||||
|
"Parameter '%s' value '%s' is not boolean in tool '%s'.",
|
||||||
|
param_name, param_value, func_name)
|
||||||
|
return lower == "true"
|
||||||
|
else:
|
||||||
|
if (param_type in ["object", "array", "arr"]
|
||||||
|
or param_type.startswith("dict")
|
||||||
|
or param_type.startswith("list")):
|
||||||
|
try:
|
||||||
|
return json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(param_value)
|
||||||
|
except (ValueError, SyntaxError, TypeError):
|
||||||
|
pass
|
||||||
|
return param_value
|
||||||
|
|
||||||
|
def _parse_xml_function_call(
|
||||||
|
self, function_call_str: str,
|
||||||
|
tools: Optional[List[ChatCompletionToolsParam]]) -> ToolCall:
|
||||||
|
end_index = function_call_str.index(">")
|
||||||
|
function_name = function_call_str[:end_index]
|
||||||
|
param_config = self._get_arguments_config(function_name, tools)
|
||||||
|
parameters = function_call_str[end_index + 1:]
|
||||||
|
param_dict: Dict[str, Any] = {}
|
||||||
|
for match_text in self.tool_call_parameter_regex.findall(parameters):
|
||||||
|
idx = match_text.index(">")
|
||||||
|
param_name = match_text[:idx]
|
||||||
|
param_value = str(match_text[idx + 1:])
|
||||||
|
if param_value.startswith("\n"):
|
||||||
|
param_value = param_value[1:]
|
||||||
|
if param_value.endswith("\n"):
|
||||||
|
param_value = param_value[:-1]
|
||||||
|
param_dict[param_name] = self._convert_param_value(
|
||||||
|
param_value, param_name, param_config, function_name)
|
||||||
|
return ToolCall(
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=function_name,
|
||||||
|
arguments=json.dumps(param_dict, ensure_ascii=False)))
|
||||||
|
|
||||||
|
def _get_function_calls(self, model_output: str) -> List[str]:
|
||||||
|
matched_ranges = self.tool_call_regex.findall(model_output)
|
||||||
|
raw_tool_calls = [
|
||||||
|
match[0] if match[0] else match[1] for match in matched_ranges
|
||||||
|
]
|
||||||
|
if not raw_tool_calls:
|
||||||
|
raw_tool_calls = [model_output]
|
||||||
|
raw_function_calls: List[tuple] = []
|
||||||
|
for tool_call in raw_tool_calls:
|
||||||
|
raw_function_calls.extend(
|
||||||
|
self.tool_call_function_regex.findall(tool_call))
|
||||||
|
return [match[0] if match[0] else match[1]
|
||||||
|
for match in raw_function_calls]
|
||||||
|
|
||||||
|
def extract_tool_calls(
|
||||||
|
self, model_output: str,
|
||||||
|
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||||
|
if self.tool_call_prefix not in model_output:
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
try:
|
||||||
|
function_calls = self._get_function_calls(model_output)
|
||||||
|
if not function_calls:
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
tool_calls = [
|
||||||
|
self._parse_xml_function_call(fc, request.tools)
|
||||||
|
for fc in function_calls
|
||||||
|
]
|
||||||
|
|
||||||
|
self.prev_tool_call_arr.clear()
|
||||||
|
for tc in tool_calls:
|
||||||
|
self.prev_tool_call_arr.append({
|
||||||
|
"name": tc.function.name,
|
||||||
|
"arguments": tc.function.arguments,
|
||||||
|
})
|
||||||
|
|
||||||
|
content_index = model_output.find(self.tool_call_start_token)
|
||||||
|
idx = model_output.find(self.tool_call_prefix)
|
||||||
|
content_index = content_index if content_index >= 0 else idx
|
||||||
|
content = model_output[:content_index]
|
||||||
|
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=bool(tool_calls),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=content if content else None,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error extracting tool call from response.")
|
||||||
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
|
tool_calls=[],
|
||||||
|
content=model_output)
|
||||||
|
|
||||||
|
def extract_tool_calls_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> Union[DeltaMessage, None]:
|
||||||
|
if not previous_text:
|
||||||
|
self._reset_streaming_state()
|
||||||
|
self.streaming_request = request
|
||||||
|
|
||||||
|
if not delta_text:
|
||||||
|
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
|
||||||
|
complete_calls = len(
|
||||||
|
self.tool_call_complete_regex.findall(current_text))
|
||||||
|
if complete_calls > 0 and self.prev_tool_call_arr:
|
||||||
|
open_calls = (
|
||||||
|
current_text.count(self.tool_call_start_token) -
|
||||||
|
current_text.count(self.tool_call_end_token))
|
||||||
|
if open_calls == 0:
|
||||||
|
return DeltaMessage(content="")
|
||||||
|
elif not self.is_tool_call_started and current_text:
|
||||||
|
return DeltaMessage(content="")
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.accumulated_text = current_text
|
||||||
|
|
||||||
|
if self.json_closed and not self.in_function:
|
||||||
|
tool_ends = current_text.count(self.tool_call_end_token)
|
||||||
|
if tool_ends > self.current_tool_index:
|
||||||
|
self.current_tool_index += 1
|
||||||
|
self.header_sent = False
|
||||||
|
self.param_count = 0
|
||||||
|
self.json_started = False
|
||||||
|
self.json_closed = False
|
||||||
|
self.accumulated_params = {}
|
||||||
|
tool_starts = current_text.count(self.tool_call_start_token)
|
||||||
|
if self.current_tool_index >= tool_starts:
|
||||||
|
self.is_tool_call_started = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self.is_tool_call_started:
|
||||||
|
if (self.tool_call_start_token_id in delta_token_ids
|
||||||
|
or self.tool_call_start_token in delta_text):
|
||||||
|
self.is_tool_call_started = True
|
||||||
|
if self.tool_call_start_token in delta_text:
|
||||||
|
content_before = delta_text[:delta_text.index(
|
||||||
|
self.tool_call_start_token)]
|
||||||
|
if content_before:
|
||||||
|
return DeltaMessage(content=content_before)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if (current_text.rstrip().endswith(self.tool_call_end_token)
|
||||||
|
and delta_text.strip() == ""):
|
||||||
|
return None
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
tool_starts_count = current_text.count(self.tool_call_start_token)
|
||||||
|
if self.current_tool_index >= tool_starts_count:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Locate the current tool call's text slice
|
||||||
|
tool_start_positions: List[int] = []
|
||||||
|
search = 0
|
||||||
|
while True:
|
||||||
|
search = current_text.find(self.tool_call_start_token, search)
|
||||||
|
if search == -1:
|
||||||
|
break
|
||||||
|
tool_start_positions.append(search)
|
||||||
|
search += len(self.tool_call_start_token)
|
||||||
|
|
||||||
|
if self.current_tool_index >= len(tool_start_positions):
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_start_idx = tool_start_positions[self.current_tool_index]
|
||||||
|
tool_end_idx = current_text.find(self.tool_call_end_token,
|
||||||
|
tool_start_idx)
|
||||||
|
if tool_end_idx == -1:
|
||||||
|
tool_text = current_text[tool_start_idx:]
|
||||||
|
else:
|
||||||
|
tool_text = current_text[tool_start_idx:tool_end_idx +
|
||||||
|
len(self.tool_call_end_token)]
|
||||||
|
|
||||||
|
if not self.header_sent:
|
||||||
|
if self.tool_call_prefix in tool_text:
|
||||||
|
func_start = (tool_text.find(self.tool_call_prefix) +
|
||||||
|
len(self.tool_call_prefix))
|
||||||
|
func_end = tool_text.find(">", func_start)
|
||||||
|
if func_end != -1:
|
||||||
|
self.current_function_name = tool_text[func_start:func_end]
|
||||||
|
self.current_tool_id = self._generate_tool_call_id()
|
||||||
|
self.header_sent = True
|
||||||
|
self.in_function = True
|
||||||
|
self.prev_tool_call_arr.append({
|
||||||
|
"name": self.current_function_name,
|
||||||
|
"arguments": "{}",
|
||||||
|
})
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
id=self.current_tool_id,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=self.current_function_name,
|
||||||
|
arguments=""),
|
||||||
|
type="function",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.in_function:
|
||||||
|
if not self.json_started:
|
||||||
|
self.json_started = True
|
||||||
|
self.streamed_args_for_tool[self.current_tool_index] += "{"
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments="{"),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Collect all complete parameters in one pass (speculative-decode safe)
|
||||||
|
param_starts: List[int] = []
|
||||||
|
search = 0
|
||||||
|
while True:
|
||||||
|
search = tool_text.find(self.parameter_prefix, search)
|
||||||
|
if search == -1:
|
||||||
|
break
|
||||||
|
param_starts.append(search)
|
||||||
|
search += len(self.parameter_prefix)
|
||||||
|
|
||||||
|
json_fragments: List[str] = []
|
||||||
|
while not self.in_param and self.param_count < len(param_starts):
|
||||||
|
param_idx = param_starts[self.param_count]
|
||||||
|
param_start = param_idx + len(self.parameter_prefix)
|
||||||
|
remaining = tool_text[param_start:]
|
||||||
|
|
||||||
|
if ">" not in remaining:
|
||||||
|
break
|
||||||
|
|
||||||
|
name_end = remaining.find(">")
|
||||||
|
current_param_name = remaining[:name_end]
|
||||||
|
value_start = param_start + name_end + 1
|
||||||
|
value_text = tool_text[value_start:]
|
||||||
|
if value_text.startswith("\n"):
|
||||||
|
value_text = value_text[1:]
|
||||||
|
|
||||||
|
param_end_idx = value_text.find(self.parameter_end_token)
|
||||||
|
if param_end_idx == -1:
|
||||||
|
next_param = value_text.find(self.parameter_prefix)
|
||||||
|
func_end = value_text.find(self.function_end_token)
|
||||||
|
if next_param != -1 and (func_end == -1
|
||||||
|
or next_param < func_end):
|
||||||
|
param_end_idx = next_param
|
||||||
|
elif func_end != -1:
|
||||||
|
param_end_idx = func_end
|
||||||
|
else:
|
||||||
|
tool_end_in_value = value_text.find(
|
||||||
|
self.tool_call_end_token)
|
||||||
|
if tool_end_in_value != -1:
|
||||||
|
param_end_idx = tool_end_in_value
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if param_end_idx == -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
param_value = value_text[:param_end_idx]
|
||||||
|
if param_value.endswith("\n"):
|
||||||
|
param_value = param_value[:-1]
|
||||||
|
|
||||||
|
self.accumulated_params[current_param_name] = param_value
|
||||||
|
param_config = self._get_arguments_config(
|
||||||
|
self.current_function_name or "",
|
||||||
|
self.streaming_request.tools
|
||||||
|
if self.streaming_request else None)
|
||||||
|
converted = self._convert_param_value(
|
||||||
|
param_value, current_param_name, param_config,
|
||||||
|
self.current_function_name or "")
|
||||||
|
serialized = json.dumps(converted, ensure_ascii=False)
|
||||||
|
|
||||||
|
sep = "" if self.param_count == 0 else ", "
|
||||||
|
json_fragments.append(
|
||||||
|
f'{sep}"{current_param_name}": {serialized}')
|
||||||
|
self.param_count += 1
|
||||||
|
|
||||||
|
if json_fragments:
|
||||||
|
combined = "".join(json_fragments)
|
||||||
|
if self.current_tool_index < len(self.streamed_args_for_tool):
|
||||||
|
self.streamed_args_for_tool[
|
||||||
|
self.current_tool_index] += combined
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"streamed_args_for_tool out of sync: index=%d len=%d",
|
||||||
|
self.current_tool_index,
|
||||||
|
len(self.streamed_args_for_tool))
|
||||||
|
return DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments=combined),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Emit closing brace when </function> is seen (after params are done)
|
||||||
|
if not self.json_closed and self.function_end_token in tool_text:
|
||||||
|
self.json_closed = True
|
||||||
|
func_start = (tool_text.find(self.tool_call_prefix) +
|
||||||
|
len(self.tool_call_prefix))
|
||||||
|
func_content_end = tool_text.find(self.function_end_token,
|
||||||
|
func_start)
|
||||||
|
if func_content_end != -1:
|
||||||
|
try:
|
||||||
|
parsed_tool = self._parse_xml_function_call(
|
||||||
|
tool_text[func_start:func_content_end],
|
||||||
|
self.streaming_request.tools
|
||||||
|
if self.streaming_request else None)
|
||||||
|
if self.current_tool_index < len(
|
||||||
|
self.prev_tool_call_arr):
|
||||||
|
self.prev_tool_call_arr[
|
||||||
|
self.current_tool_index]["arguments"] = (
|
||||||
|
parsed_tool.function.arguments)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to parse tool call during "
|
||||||
|
"streaming: %s",
|
||||||
|
tool_text,
|
||||||
|
exc_info=True)
|
||||||
|
|
||||||
|
if self.current_tool_index < len(self.streamed_args_for_tool):
|
||||||
|
self.streamed_args_for_tool[
|
||||||
|
self.current_tool_index] += "}"
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"streamed_args_for_tool out of sync: index=%d len=%d",
|
||||||
|
self.current_tool_index,
|
||||||
|
len(self.streamed_args_for_tool))
|
||||||
|
|
||||||
|
result = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments="}"),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
self.in_function = False
|
||||||
|
self.accumulated_params = {}
|
||||||
|
return result
|
||||||
|
|
||||||
|
return None
|
||||||
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Reasoning parser module for vLLM 0.6.3 (BI-V100 / Qwen3.6-27B adaptation).
|
||||||
|
|
||||||
|
Usage: --reasoning-parser qwen3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
__all__ = ["ReasoningParser", "ReasoningParserManager"]
|
||||||
|
|
||||||
|
# Lazy-register Qwen3 parser; imported on first get_reasoning_parser("qwen3").
|
||||||
|
ReasoningParserManager.register_lazy(
|
||||||
|
"qwen3",
|
||||||
|
"vllm.reasoning.qwen3_reasoning_parser",
|
||||||
|
"Qwen3ReasoningParser",
|
||||||
|
)
|
||||||
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""
|
||||||
|
Abstract reasoning parser base classes for vLLM 0.6.3.
|
||||||
|
Adapted from vllm-original/vllm/reasoning/abs_reasoning_parsers.py:
|
||||||
|
- Removed vllm.entrypoints.mcp, vllm.utils.collection_utils, import_utils
|
||||||
|
- DeltaMessage from vllm 0.6.3 protocol path
|
||||||
|
- TokenizerLike -> AnyTokenizer
|
||||||
|
- ReasoningParserManager: simplified eager + lazy registration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
else:
|
||||||
|
DeltaMessage = Any
|
||||||
|
AnyTokenizer = Any
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningParser:
|
||||||
|
"""Abstract base for all reasoning parsers."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||||
|
self.model_tokenizer = tokenizer
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def vocab(self) -> dict:
|
||||||
|
return self.model_tokenizer.get_vocab()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||||
|
"""Return True once the reasoning block has closed in input_ids."""
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||||
|
) -> bool:
|
||||||
|
return self.is_reasoning_end(input_ids)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract_content_ids(self, input_ids: list) -> list:
|
||||||
|
"""Return token ids that belong to the content (post-reasoning) part."""
|
||||||
|
|
||||||
|
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract_reasoning(
|
||||||
|
self, model_output: str, request: Any
|
||||||
|
) -> "tuple[Optional[str], Optional[str]]":
|
||||||
|
"""
|
||||||
|
Split a complete model output into (reasoning_text, content_text).
|
||||||
|
Either part may be None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract_reasoning_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
) -> Optional["DeltaMessage"]:
|
||||||
|
"""
|
||||||
|
Extract reasoning from a streaming delta.
|
||||||
|
Returns a DeltaMessage with reasoning_content and/or content set,
|
||||||
|
or None if this delta should be suppressed (control token).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseThinkingReasoningParser(ReasoningParser):
|
||||||
|
"""
|
||||||
|
Base for parsers that use <start_token>...</end_token> delimiters.
|
||||||
|
Subclasses define start_token / end_token properties.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def start_token(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def end_token(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||||
|
super().__init__(tokenizer, *args, **kwargs)
|
||||||
|
|
||||||
|
if not self.model_tokenizer:
|
||||||
|
raise ValueError("Tokenizer must be passed to ReasoningParser.")
|
||||||
|
if not self.start_token or not self.end_token:
|
||||||
|
raise ValueError("start_token and end_token must be defined.")
|
||||||
|
|
||||||
|
self.start_token_id: Optional[int] = self.vocab.get(self.start_token)
|
||||||
|
self.end_token_id: Optional[int] = self.vocab.get(self.end_token)
|
||||||
|
if self.start_token_id is None or self.end_token_id is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{self.__class__.__name__}: could not find think tokens "
|
||||||
|
f"'{self.start_token}'/'{self.end_token}' in tokenizer vocab."
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||||
|
for token_id in reversed(input_ids):
|
||||||
|
if token_id == self.start_token_id:
|
||||||
|
return False
|
||||||
|
if token_id == self.end_token_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||||
|
) -> bool:
|
||||||
|
return self.end_token_id in delta_ids
|
||||||
|
|
||||||
|
def extract_content_ids(self, input_ids: list) -> list:
|
||||||
|
if self.end_token_id not in input_ids[:-1]:
|
||||||
|
return []
|
||||||
|
return input_ids[input_ids.index(self.end_token_id) + 1:]
|
||||||
|
|
||||||
|
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||||
|
count = 0
|
||||||
|
depth = 0
|
||||||
|
for tid in token_ids:
|
||||||
|
if tid == self.start_token_id:
|
||||||
|
depth += 1
|
||||||
|
elif tid == self.end_token_id:
|
||||||
|
if depth > 0:
|
||||||
|
depth -= 1
|
||||||
|
elif depth > 0:
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
def extract_reasoning(
|
||||||
|
self, model_output: str, request: Any
|
||||||
|
) -> "tuple[Optional[str], Optional[str]]":
|
||||||
|
# Strip <think> if the model generated it (old-style template).
|
||||||
|
parts = model_output.partition(self.start_token)
|
||||||
|
model_output = parts[2] if parts[1] else parts[0]
|
||||||
|
|
||||||
|
if self.end_token not in model_output:
|
||||||
|
return model_output, None
|
||||||
|
reasoning, _, content = model_output.partition(self.end_token)
|
||||||
|
return reasoning, content or None
|
||||||
|
|
||||||
|
def extract_reasoning_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
) -> Optional["DeltaMessage"]:
|
||||||
|
from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage
|
||||||
|
|
||||||
|
# Suppress lone control tokens.
|
||||||
|
if len(delta_token_ids) == 1 and delta_token_ids[0] in (
|
||||||
|
self.start_token_id, self.end_token_id
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
start_in_prev = self.start_token_id in previous_token_ids
|
||||||
|
start_in_delta = self.start_token_id in delta_token_ids
|
||||||
|
end_in_prev = self.end_token_id in previous_token_ids
|
||||||
|
end_in_delta = self.end_token_id in delta_token_ids
|
||||||
|
|
||||||
|
if start_in_prev:
|
||||||
|
if end_in_delta:
|
||||||
|
end_idx = delta_text.find(self.end_token)
|
||||||
|
reasoning = delta_text[:end_idx] if end_idx >= 0 else ""
|
||||||
|
content = delta_text[end_idx + len(self.end_token):] if end_idx >= 0 else None
|
||||||
|
return _DeltaMessage(
|
||||||
|
reasoning_content=reasoning or None,
|
||||||
|
content=content or None,
|
||||||
|
)
|
||||||
|
elif end_in_prev:
|
||||||
|
return _DeltaMessage(content=delta_text)
|
||||||
|
else:
|
||||||
|
return _DeltaMessage(reasoning_content=delta_text)
|
||||||
|
|
||||||
|
elif start_in_delta:
|
||||||
|
if end_in_delta:
|
||||||
|
start_idx = delta_text.find(self.start_token)
|
||||||
|
end_idx = delta_text.find(self.end_token)
|
||||||
|
reasoning = delta_text[start_idx + len(self.start_token):end_idx]
|
||||||
|
content = delta_text[end_idx + len(self.end_token):]
|
||||||
|
return _DeltaMessage(
|
||||||
|
reasoning_content=reasoning or None,
|
||||||
|
content=content or None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return _DeltaMessage(reasoning_content=delta_text)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return _DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningParserManager:
|
||||||
|
"""
|
||||||
|
Registry for ReasoningParser implementations.
|
||||||
|
Supports eager and lazy registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_parsers: dict = {} # name -> class (eager)
|
||||||
|
_lazy: dict = {} # name -> (module_path, class_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_module(cls, name: str, parser_cls: type) -> None:
|
||||||
|
"""Eagerly register a ReasoningParser class."""
|
||||||
|
if not issubclass(parser_cls, ReasoningParser):
|
||||||
|
raise TypeError(f"{parser_cls} is not a ReasoningParser subclass.")
|
||||||
|
cls._parsers[name] = parser_cls
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_lazy(cls, name: str, module_path: str, class_name: str) -> None:
|
||||||
|
"""Register a parser for deferred import."""
|
||||||
|
cls._lazy[name] = (module_path, class_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_reasoning_parser(cls, name: str) -> type:
|
||||||
|
if name in cls._parsers:
|
||||||
|
return cls._parsers[name]
|
||||||
|
if name in cls._lazy:
|
||||||
|
module_path, class_name = cls._lazy[name]
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
parser_cls = getattr(mod, class_name)
|
||||||
|
cls._parsers[name] = parser_cls
|
||||||
|
return parser_cls
|
||||||
|
registered = sorted(set(cls._parsers) | set(cls._lazy))
|
||||||
|
raise KeyError(
|
||||||
|
f"Reasoning parser '{name}' not found. "
|
||||||
|
f"Available: {registered}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_registered(cls) -> list:
|
||||||
|
return sorted(set(cls._parsers) | set(cls._lazy))
|
||||||
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Reasoning parser for Qwen3 / Qwen3.5 / Qwen3.6 model family.
|
||||||
|
Adapted from vllm-original/vllm/reasoning/qwen3_reasoning_parser.py.
|
||||||
|
|
||||||
|
The model uses <think>...</think> to wrap chain-of-thought output.
|
||||||
|
For Qwen3.5+ the chat template injects <think> into the prompt, so only
|
||||||
|
</think> appears in the generated tokens; older templates generate <think>
|
||||||
|
themselves. Both styles are handled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Sequence, Any
|
||||||
|
|
||||||
|
from vllm.reasoning.abs_reasoning_parsers import (
|
||||||
|
BaseThinkingReasoningParser,
|
||||||
|
ReasoningParserManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3ReasoningParser(BaseThinkingReasoningParser):
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: Any, *args, **kwargs):
|
||||||
|
super().__init__(tokenizer, *args, **kwargs)
|
||||||
|
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||||
|
self.thinking_enabled = chat_kwargs.get("enable_thinking", True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
return "<think>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
return "</think>"
|
||||||
|
|
||||||
|
def extract_reasoning(
|
||||||
|
self, model_output: str, request: Any
|
||||||
|
) -> "tuple[Optional[str], Optional[str]]":
|
||||||
|
# Strip <think> if the model generated it (old template / edge case).
|
||||||
|
parts = model_output.partition(self.start_token)
|
||||||
|
model_output = parts[2] if parts[1] else parts[0]
|
||||||
|
|
||||||
|
if self.end_token not in model_output:
|
||||||
|
if not self.thinking_enabled:
|
||||||
|
return None, model_output
|
||||||
|
# Thinking enabled but output truncated before </think>.
|
||||||
|
return model_output, None
|
||||||
|
|
||||||
|
reasoning, _, content = model_output.partition(self.end_token)
|
||||||
|
return reasoning, content or None
|
||||||
|
|
||||||
|
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||||
|
token_ids = list(token_ids)
|
||||||
|
if self.start_token_id in token_ids:
|
||||||
|
# Old-style template: model generates <think> itself.
|
||||||
|
# Use depth-counting from the base class.
|
||||||
|
return super().count_reasoning_tokens(token_ids)
|
||||||
|
elif self.end_token_id in token_ids:
|
||||||
|
# New-style template (Qwen3.5+): <think> is injected into the
|
||||||
|
# prompt, so output starts already inside the thinking block.
|
||||||
|
# Every token before </think> is a reasoning token.
|
||||||
|
return token_ids.index(self.end_token_id)
|
||||||
|
else:
|
||||||
|
# No </think> in output: either truncated (all reasoning)
|
||||||
|
# or thinking disabled (none).
|
||||||
|
return len(token_ids) if self.thinking_enabled else 0
|
||||||
|
|
||||||
|
def extract_reasoning_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
):
|
||||||
|
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||||
|
|
||||||
|
if not self.thinking_enabled:
|
||||||
|
return DeltaMessage(content=delta_text) if delta_text else None
|
||||||
|
|
||||||
|
# Strip <think> from delta if the model generates it itself.
|
||||||
|
if self.start_token_id in delta_token_ids:
|
||||||
|
start_idx = delta_text.find(self.start_token)
|
||||||
|
if start_idx >= 0:
|
||||||
|
delta_text = delta_text[start_idx + len(self.start_token):]
|
||||||
|
|
||||||
|
if self.end_token_id in delta_token_ids:
|
||||||
|
end_idx = delta_text.find(self.end_token)
|
||||||
|
if end_idx >= 0:
|
||||||
|
reasoning = delta_text[:end_idx]
|
||||||
|
content = delta_text[end_idx + len(self.end_token):]
|
||||||
|
if not reasoning and not content:
|
||||||
|
return None
|
||||||
|
return DeltaMessage(
|
||||||
|
reasoning_content=reasoning or None,
|
||||||
|
content=content or None,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not delta_text:
|
||||||
|
return None
|
||||||
|
elif self.end_token_id in previous_token_ids:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
else:
|
||||||
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
|
||||||
|
|
||||||
|
# Register immediately when this module is imported.
|
||||||
|
ReasoningParserManager.register_module("qwen3", Qwen3ReasoningParser)
|
||||||
1385
qwen3_6_scripts/sequence.py
Normal file
1385
qwen3_6_scripts/sequence.py
Normal file
File diff suppressed because it is too large
Load Diff
999
qwen3_6_scripts/serving_chat.py
Normal file
999
qwen3_6_scripts/serving_chat.py
Normal file
@@ -0,0 +1,999 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
|
||||||
|
Optional)
|
||||||
|
from typing import Sequence as GenericSequence
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||||
|
apply_hf_chat_template,
|
||||||
|
apply_mistral_chat_template,
|
||||||
|
load_chat_template,
|
||||||
|
parse_chat_messages_futures)
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||||
|
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||||
|
ChatCompletionRequest, ChatCompletionResponse,
|
||||||
|
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||||
|
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||||
|
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
|
||||||
|
ToolCall, UsageInfo)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||||
|
LoRAModulePath,
|
||||||
|
OpenAIServing,
|
||||||
|
PromptAdapterPath,
|
||||||
|
TextTokensPrompt)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||||
|
from vllm.inputs import TokensPrompt
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
|
from vllm.sequence import Logprob
|
||||||
|
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||||
|
log_tracing_disabled_warning)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingChat(OpenAIServing):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
engine_client: EngineClient,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
base_model_paths: List[BaseModelPath],
|
||||||
|
response_role: str,
|
||||||
|
*,
|
||||||
|
lora_modules: Optional[List[LoRAModulePath]],
|
||||||
|
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||||
|
request_logger: Optional[RequestLogger],
|
||||||
|
chat_template: Optional[str],
|
||||||
|
return_tokens_as_token_ids: bool = False,
|
||||||
|
enable_auto_tools: bool = False,
|
||||||
|
tool_parser: Optional[str] = None,
|
||||||
|
reasoning_parser: Optional[str] = None):
|
||||||
|
super().__init__(engine_client=engine_client,
|
||||||
|
model_config=model_config,
|
||||||
|
base_model_paths=base_model_paths,
|
||||||
|
lora_modules=lora_modules,
|
||||||
|
prompt_adapters=prompt_adapters,
|
||||||
|
request_logger=request_logger,
|
||||||
|
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||||
|
|
||||||
|
self.response_role = response_role
|
||||||
|
self.use_tool_use_model_template = False
|
||||||
|
self.chat_template = load_chat_template(chat_template)
|
||||||
|
|
||||||
|
# set up tool use
|
||||||
|
self.enable_auto_tools: bool = enable_auto_tools
|
||||||
|
if self.enable_auto_tools:
|
||||||
|
logger.info(
|
||||||
|
"\"auto\" tool choice has been enabled please note that while"
|
||||||
|
" the parallel_tool_calls client option is preset for "
|
||||||
|
"compatibility reasons, it will be ignored.")
|
||||||
|
|
||||||
|
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||||
|
if self.enable_auto_tools:
|
||||||
|
try:
|
||||||
|
self.tool_parser = ToolParserManager.get_tool_parser(
|
||||||
|
tool_parser)
|
||||||
|
except Exception as e:
|
||||||
|
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||||
|
f"tool_parser:'{tool_parser}' which has not "
|
||||||
|
"been registered") from e
|
||||||
|
|
||||||
|
# set up reasoning parser
|
||||||
|
self.reasoning_parser_cls = None
|
||||||
|
if reasoning_parser:
|
||||||
|
try:
|
||||||
|
from vllm.reasoning import ReasoningParserManager
|
||||||
|
self.reasoning_parser_cls = \
|
||||||
|
ReasoningParserManager.get_reasoning_parser(reasoning_parser)
|
||||||
|
logger.info("Reasoning parser '%s' enabled.", reasoning_parser)
|
||||||
|
except Exception as e:
|
||||||
|
raise TypeError(
|
||||||
|
f"Error: --reasoning-parser '{reasoning_parser}' could not "
|
||||||
|
"be loaded. Make sure vllm/reasoning/ is installed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def create_chat_completion(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
raw_request: Optional[Request] = None,
|
||||||
|
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||||
|
ErrorResponse]:
|
||||||
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
|
See https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
for the API specification. This API mimics the OpenAI
|
||||||
|
ChatCompletion API.
|
||||||
|
|
||||||
|
"""
|
||||||
|
error_check_ret = await self._check_model(request)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
logger.error("Error with model %s", error_check_ret)
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||||
|
# This is required for the streaming case, where we return a
|
||||||
|
# success status before we actually start generating text :).
|
||||||
|
if self.engine_client.errored:
|
||||||
|
raise self.engine_client.dead_error
|
||||||
|
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
lora_request,
|
||||||
|
prompt_adapter_request,
|
||||||
|
) = self._maybe_get_adapters(request)
|
||||||
|
|
||||||
|
model_config = self.model_config
|
||||||
|
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||||
|
|
||||||
|
conversation, mm_data_future = parse_chat_messages_futures(
|
||||||
|
request.messages, model_config, tokenizer)
|
||||||
|
|
||||||
|
tool_dicts = None if request.tools is None else [
|
||||||
|
tool.model_dump() for tool in request.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt: Union[str, List[int]]
|
||||||
|
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||||
|
if is_mistral_tokenizer:
|
||||||
|
prompt = apply_mistral_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
messages=request.messages,
|
||||||
|
chat_template=request.chat_template or self.chat_template,
|
||||||
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
continue_final_message=request.continue_final_message,
|
||||||
|
tools=tool_dicts,
|
||||||
|
documents=request.documents,
|
||||||
|
**(request.chat_template_kwargs or {}),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = apply_hf_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
conversation=conversation,
|
||||||
|
chat_template=request.chat_template or self.chat_template,
|
||||||
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
continue_final_message=request.continue_final_message,
|
||||||
|
tools=tool_dicts,
|
||||||
|
documents=request.documents,
|
||||||
|
**(request.chat_template_kwargs or {}),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in applying chat template from request")
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
try:
|
||||||
|
mm_data = await mm_data_future
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in loading multi-modal data")
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
# validation for OpenAI tools
|
||||||
|
# tool_choice = "required" is not supported
|
||||||
|
if request.tool_choice == "required":
|
||||||
|
return self.create_error_response(
|
||||||
|
"tool_choice = \"required\" is not supported!")
|
||||||
|
|
||||||
|
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
|
||||||
|
self.enable_auto_tools and self.tool_parser is not None):
|
||||||
|
# for hf tokenizers, "auto" tools requires
|
||||||
|
# --enable-auto-tool-choice and --tool-call-parser
|
||||||
|
return self.create_error_response(
|
||||||
|
"\"auto\" tool choice requires "
|
||||||
|
"--enable-auto-tool-choice and --tool-call-parser to be set")
|
||||||
|
|
||||||
|
request_id = f"chat-{random_uuid()}"
|
||||||
|
|
||||||
|
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||||
|
if raw_request:
|
||||||
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.enable_auto_tools and self.tool_parser:
|
||||||
|
request = self.tool_parser(tokenizer).adjust_request(
|
||||||
|
request=request)
|
||||||
|
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt_inputs = self._tokenize_prompt_input(
|
||||||
|
request,
|
||||||
|
tokenizer,
|
||||||
|
prompt,
|
||||||
|
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||||
|
add_special_tokens=request.add_special_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert isinstance(prompt, list) and isinstance(
|
||||||
|
prompt[0], int
|
||||||
|
), "Prompt has to be either a string or a list of token ids"
|
||||||
|
prompt_inputs = TextTokensPrompt(
|
||||||
|
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
|
||||||
|
|
||||||
|
assert prompt_inputs is not None
|
||||||
|
|
||||||
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
|
default_max_tokens = self.max_model_len - len(
|
||||||
|
prompt_inputs["prompt_token_ids"])
|
||||||
|
if request.use_beam_search:
|
||||||
|
sampling_params = request.to_beam_search_params(
|
||||||
|
default_max_tokens)
|
||||||
|
else:
|
||||||
|
sampling_params = request.to_sampling_params(
|
||||||
|
default_max_tokens)
|
||||||
|
|
||||||
|
self._log_inputs(request_id,
|
||||||
|
prompt_inputs,
|
||||||
|
params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request)
|
||||||
|
|
||||||
|
engine_inputs = TokensPrompt(
|
||||||
|
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||||
|
if mm_data is not None:
|
||||||
|
engine_inputs["multi_modal_data"] = mm_data
|
||||||
|
|
||||||
|
is_tracing_enabled = (await
|
||||||
|
self.engine_client.is_tracing_enabled())
|
||||||
|
trace_headers = None
|
||||||
|
if is_tracing_enabled and raw_request:
|
||||||
|
trace_headers = extract_trace_headers(raw_request.headers)
|
||||||
|
if (not is_tracing_enabled and raw_request
|
||||||
|
and contains_trace_headers(raw_request.headers)):
|
||||||
|
log_tracing_disabled_warning()
|
||||||
|
|
||||||
|
if isinstance(sampling_params, BeamSearchParams):
|
||||||
|
assert isinstance(self.engine_client,
|
||||||
|
(AsyncLLMEngine,
|
||||||
|
MQLLMEngineClient)), \
|
||||||
|
"Beam search is only supported with" \
|
||||||
|
"AsyncLLMEngine and MQLLMEngineClient."
|
||||||
|
result_generator = self.engine_client.beam_search(
|
||||||
|
engine_inputs['prompt_token_ids'],
|
||||||
|
request_id,
|
||||||
|
sampling_params,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result_generator = self.engine_client.generate(
|
||||||
|
engine_inputs,
|
||||||
|
sampling_params,
|
||||||
|
request_id,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
priority=request.priority,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
if raw_request:
|
||||||
|
result_generator = iterate_with_cancellation(
|
||||||
|
result_generator, raw_request.is_disconnected)
|
||||||
|
|
||||||
|
# Streaming response
|
||||||
|
if request.stream:
|
||||||
|
return self.chat_completion_stream_generator(
|
||||||
|
request, result_generator, request_id, conversation, tokenizer,
|
||||||
|
request_metadata)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.chat_completion_full_generator(
|
||||||
|
request, result_generator, request_id, conversation, tokenizer,
|
||||||
|
request_metadata)
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
||||||
|
if request.add_generation_prompt:
|
||||||
|
return self.response_role
|
||||||
|
return request.messages[-1]["role"]
|
||||||
|
|
||||||
|
async def chat_completion_stream_generator(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
result_generator: AsyncIterator[RequestOutput],
|
||||||
|
request_id: str,
|
||||||
|
conversation: List[ConversationMessage],
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
request_metadata: RequestResponseMetadata,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
model_name = self.base_model_paths[0].name
|
||||||
|
created_time = int(time.time())
|
||||||
|
chunk_object_type: Final = "chat.completion.chunk"
|
||||||
|
first_iteration = True
|
||||||
|
|
||||||
|
# Send response for each token for each request.n (index)
|
||||||
|
num_choices = 1 if request.n is None else request.n
|
||||||
|
previous_num_tokens = [0] * num_choices
|
||||||
|
finish_reason_sent = [False] * num_choices
|
||||||
|
num_prompt_tokens = 0
|
||||||
|
|
||||||
|
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||||
|
tool_choice_function_name = request.tool_choice.function.name
|
||||||
|
else:
|
||||||
|
tool_choice_function_name = None
|
||||||
|
|
||||||
|
# Determine whether tools are in use with "auto" tool choice
|
||||||
|
tool_choice_auto = (
|
||||||
|
not tool_choice_function_name
|
||||||
|
and self._should_stream_with_auto_tool_parsing(request))
|
||||||
|
|
||||||
|
use_reasoning = self.reasoning_parser_cls is not None
|
||||||
|
|
||||||
|
all_previous_token_ids: Optional[List[List[int]]]
|
||||||
|
# previous_texts / all_previous_token_ids are needed for both tool
|
||||||
|
# parsing and reasoning parsing (both require full-history context).
|
||||||
|
if tool_choice_auto or use_reasoning:
|
||||||
|
previous_texts = [""] * num_choices
|
||||||
|
all_previous_token_ids = [[]] * num_choices
|
||||||
|
else:
|
||||||
|
previous_texts, all_previous_token_ids = None, None
|
||||||
|
|
||||||
|
# Prepare the tool parser if it's needed
|
||||||
|
try:
|
||||||
|
if tool_choice_auto and self.tool_parser:
|
||||||
|
tool_parsers: List[Optional[ToolParser]] = [
|
||||||
|
self.tool_parser(tokenizer)
|
||||||
|
] * num_choices
|
||||||
|
else:
|
||||||
|
tool_parsers = [None] * num_choices
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error("Error in tool parser creation: %s", e)
|
||||||
|
data = self.create_streaming_error_response(str(e))
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# Prepare reasoning parsers (one instance per choice for state isolation)
|
||||||
|
reasoning_parsers: List[Optional[object]] = [None] * num_choices
|
||||||
|
reasoning_end_arr: List[bool] = [False] * num_choices
|
||||||
|
reasoning_token_counts: List[int] = [0] * num_choices
|
||||||
|
if use_reasoning:
|
||||||
|
try:
|
||||||
|
reasoning_parsers = [
|
||||||
|
self.reasoning_parser_cls(
|
||||||
|
tokenizer,
|
||||||
|
chat_template_kwargs=request.chat_template_kwargs)
|
||||||
|
for _ in range(num_choices)
|
||||||
|
]
|
||||||
|
# If thinking is disabled per-request, mark reasoning as
|
||||||
|
# already ended so the tool-auto branch is reachable.
|
||||||
|
for idx, rp in enumerate(reasoning_parsers):
|
||||||
|
if hasattr(rp, 'thinking_enabled') and not rp.thinking_enabled:
|
||||||
|
reasoning_end_arr[idx] = True
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error("Error in reasoning parser creation: %s", e)
|
||||||
|
data = self.create_streaming_error_response(str(e))
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for res in result_generator:
|
||||||
|
if res.prompt_token_ids is not None:
|
||||||
|
num_prompt_tokens = len(res.prompt_token_ids)
|
||||||
|
if res.encoder_prompt_token_ids is not None:
|
||||||
|
num_prompt_tokens += len(res.encoder_prompt_token_ids)
|
||||||
|
|
||||||
|
# We need to do it here, because if there are exceptions in
|
||||||
|
# the result_generator, it needs to be sent as the FIRST
|
||||||
|
# response (by the try...catch).
|
||||||
|
if first_iteration:
|
||||||
|
# Send first response for each request.n (index) with
|
||||||
|
# the role
|
||||||
|
role = self.get_chat_request_role(request)
|
||||||
|
|
||||||
|
# NOTE num_choices defaults to 1 so this usually executes
|
||||||
|
# once per request
|
||||||
|
for i in range(num_choices):
|
||||||
|
tool_parser = tool_parsers[i]
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=DeltaMessage(
|
||||||
|
role=role,
|
||||||
|
content="",
|
||||||
|
),
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name)
|
||||||
|
|
||||||
|
# if usage should be included
|
||||||
|
if (request.stream_options
|
||||||
|
and request.stream_options.include_usage):
|
||||||
|
# if continuous usage stats are requested, add it
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=num_prompt_tokens)
|
||||||
|
chunk.usage = usage
|
||||||
|
# otherwise don't
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# Send response to echo the input portion of the
|
||||||
|
# last message
|
||||||
|
if request.echo or request.continue_final_message:
|
||||||
|
last_msg_content: str = ""
|
||||||
|
if conversation and "content" in conversation[
|
||||||
|
-1] and conversation[-1].get("role") == role:
|
||||||
|
last_msg_content = conversation[-1]["content"] or ""
|
||||||
|
|
||||||
|
if last_msg_content:
|
||||||
|
for i in range(num_choices):
|
||||||
|
choice_data = (
|
||||||
|
ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=DeltaMessage(
|
||||||
|
content=last_msg_content),
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None))
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name)
|
||||||
|
if (request.stream_options and
|
||||||
|
request.stream_options.include_usage):
|
||||||
|
if (request.stream_options.
|
||||||
|
continuous_usage_stats):
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=num_prompt_tokens)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(
|
||||||
|
exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
first_iteration = False
|
||||||
|
|
||||||
|
for output in res.outputs:
|
||||||
|
i = output.index
|
||||||
|
tool_parser = tool_parsers[i]
|
||||||
|
|
||||||
|
if finish_reason_sent[i]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if request.logprobs and request.top_logprobs is not None:
|
||||||
|
assert output.logprobs is not None, (
|
||||||
|
"Did not output logprobs")
|
||||||
|
logprobs = self._create_chat_logprobs(
|
||||||
|
token_ids=output.token_ids,
|
||||||
|
top_logprobs=output.logprobs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_output_top_logprobs=request.top_logprobs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
delta_text = output.text
|
||||||
|
delta_message: Optional[DeltaMessage]
|
||||||
|
|
||||||
|
# Maintain text/token history when either reasoning or
|
||||||
|
# auto-tool parsing is active.
|
||||||
|
assert previous_texts is not None or not (
|
||||||
|
tool_choice_auto or use_reasoning)
|
||||||
|
if previous_texts is not None:
|
||||||
|
assert all_previous_token_ids is not None
|
||||||
|
previous_text = previous_texts[i]
|
||||||
|
previous_token_ids = all_previous_token_ids[i]
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
current_token_ids = previous_token_ids + list(
|
||||||
|
output.token_ids)
|
||||||
|
previous_texts[i] = current_text
|
||||||
|
all_previous_token_ids[i] = current_token_ids
|
||||||
|
else:
|
||||||
|
previous_text = ""
|
||||||
|
previous_token_ids = []
|
||||||
|
current_text = delta_text
|
||||||
|
current_token_ids = list(output.token_ids)
|
||||||
|
|
||||||
|
# handle streaming deltas for tools with named tool_choice
|
||||||
|
if tool_choice_function_name:
|
||||||
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(function=DeltaFunctionCall(
|
||||||
|
name=tool_choice_function_name,
|
||||||
|
arguments=delta_text),
|
||||||
|
index=i)
|
||||||
|
])
|
||||||
|
|
||||||
|
# handle reasoning: route through reasoning parser while
|
||||||
|
# </think> has not yet been seen.
|
||||||
|
elif use_reasoning and not reasoning_end_arr[i]:
|
||||||
|
r_parser = reasoning_parsers[i]
|
||||||
|
delta_message = r_parser.extract_reasoning_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=previous_token_ids,
|
||||||
|
current_token_ids=current_token_ids,
|
||||||
|
delta_token_ids=output.token_ids,
|
||||||
|
)
|
||||||
|
# Mark reasoning as ended when end token appears.
|
||||||
|
if r_parser.end_token_id in current_token_ids:
|
||||||
|
reasoning_end_arr[i] = True
|
||||||
|
|
||||||
|
# handle streaming deltas for tools with "auto" tool choice
|
||||||
|
# (only reached after reasoning block, if any, has ended)
|
||||||
|
elif tool_choice_auto:
|
||||||
|
assert tool_parser is not None
|
||||||
|
delta_message = (
|
||||||
|
tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=previous_token_ids,
|
||||||
|
current_token_ids=current_token_ids,
|
||||||
|
delta_token_ids=output.token_ids,
|
||||||
|
request=request))
|
||||||
|
|
||||||
|
# handle streaming just a content delta
|
||||||
|
else:
|
||||||
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
# set the previous values for the next iteration
|
||||||
|
previous_num_tokens[i] += len(output.token_ids)
|
||||||
|
|
||||||
|
# if the message delta is None (e.g. because it was a
|
||||||
|
# "control token" for tool calls or the parser otherwise
|
||||||
|
# wasn't ready to send a token, then
|
||||||
|
# get the next token without streaming a chunk.
|
||||||
|
# However, if this is the finish token we must NOT skip —
|
||||||
|
# the finish block updates reasoning_token_counts, sets
|
||||||
|
# finish_reason_sent, and flushes the final usage chunk.
|
||||||
|
if delta_message is None:
|
||||||
|
if output.finish_reason is None:
|
||||||
|
continue
|
||||||
|
delta_message = DeltaMessage()
|
||||||
|
|
||||||
|
if output.finish_reason is None:
|
||||||
|
# Send token-by-token response for each request.n
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=delta_message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=None)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name)
|
||||||
|
|
||||||
|
# handle usage stats if requested & if continuous
|
||||||
|
if (request.stream_options
|
||||||
|
and request.stream_options.include_usage):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
completion_tokens = len(output.token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=num_prompt_tokens +
|
||||||
|
completion_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# if the model is finished generating
|
||||||
|
else:
|
||||||
|
# check to make sure we haven't "forgotten" to stream
|
||||||
|
# any tokens that were generated but previously
|
||||||
|
# matched by partial json parsing
|
||||||
|
# only happens if we are NOT using guided decoding
|
||||||
|
auto_tools_called = False
|
||||||
|
if tool_parser:
|
||||||
|
auto_tools_called = len(
|
||||||
|
tool_parser.prev_tool_call_arr) > 0
|
||||||
|
index = len(tool_parser.prev_tool_call_arr
|
||||||
|
) - 1 if auto_tools_called else 0
|
||||||
|
else:
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
if self._should_check_for_unstreamed_tool_arg_tokens(
|
||||||
|
delta_message, output) and tool_parser:
|
||||||
|
# get the expected call based on partial JSON
|
||||||
|
# parsing which "autocompletes" the JSON
|
||||||
|
expected_call = json.dumps(
|
||||||
|
tool_parser.prev_tool_call_arr[index].get(
|
||||||
|
"arguments", {}))
|
||||||
|
|
||||||
|
# get what we've streamed so far for arguments
|
||||||
|
# for the current tool
|
||||||
|
actual_call = tool_parser.streamed_args_for_tool[
|
||||||
|
index]
|
||||||
|
|
||||||
|
# check to see if there's anything left to stream
|
||||||
|
remaining_call = expected_call.replace(
|
||||||
|
actual_call, "", 1)
|
||||||
|
|
||||||
|
# set that as a delta message
|
||||||
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(index=index,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=remaining_call).
|
||||||
|
model_dump(exclude_none=True))
|
||||||
|
])
|
||||||
|
|
||||||
|
# Count reasoning tokens for this choice at finish time.
|
||||||
|
if use_reasoning and all_previous_token_ids is not None:
|
||||||
|
r_parser = reasoning_parsers[i]
|
||||||
|
reasoning_token_counts[i] = \
|
||||||
|
r_parser.count_reasoning_tokens(
|
||||||
|
all_previous_token_ids[i])
|
||||||
|
|
||||||
|
# Send the finish response for each request.n only once
|
||||||
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
delta=delta_message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=output.finish_reason
|
||||||
|
if not auto_tools_called else "tool_calls",
|
||||||
|
stop_reason=output.stop_reason)
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name)
|
||||||
|
if (request.stream_options
|
||||||
|
and request.stream_options.include_usage):
|
||||||
|
if request.stream_options.continuous_usage_stats:
|
||||||
|
completion_tokens = len(output.token_ids)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=num_prompt_tokens +
|
||||||
|
completion_tokens,
|
||||||
|
)
|
||||||
|
chunk.usage = usage
|
||||||
|
else:
|
||||||
|
chunk.usage = None
|
||||||
|
data = chunk.model_dump_json(exclude_unset=True)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
finish_reason_sent[i] = True
|
||||||
|
|
||||||
|
# once the final token is handled, if stream_options.include_usage
|
||||||
|
# is sent, send the usage
|
||||||
|
if (request.stream_options
|
||||||
|
and request.stream_options.include_usage):
|
||||||
|
completion_tokens = previous_num_tokens[i]
|
||||||
|
total_reasoning = sum(reasoning_token_counts) if use_reasoning else None
|
||||||
|
final_usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=num_prompt_tokens + completion_tokens,
|
||||||
|
reasoning_tokens=total_reasoning,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_usage_chunk = ChatCompletionStreamResponse(
|
||||||
|
id=request_id,
|
||||||
|
object=chunk_object_type,
|
||||||
|
created=created_time,
|
||||||
|
choices=[],
|
||||||
|
model=model_name,
|
||||||
|
usage=final_usage)
|
||||||
|
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||||
|
exclude_unset=True, exclude_none=True))
|
||||||
|
yield f"data: {final_usage_data}\n\n"
|
||||||
|
|
||||||
|
# report to FastAPI middleware aggregate usage across all choices
|
||||||
|
num_completion_tokens = sum(previous_num_tokens)
|
||||||
|
total_reasoning = sum(reasoning_token_counts) if use_reasoning else None
|
||||||
|
request_metadata.final_usage_info = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=num_completion_tokens,
|
||||||
|
total_tokens=num_prompt_tokens + num_completion_tokens,
|
||||||
|
reasoning_tokens=total_reasoning)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
logger.error("error in chat completion stream generator: %s", e)
|
||||||
|
data = self.create_streaming_error_response(str(e))
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
# Send the final done message after all response.n are finished
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
async def chat_completion_full_generator(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
result_generator: AsyncIterator[RequestOutput],
|
||||||
|
request_id: str,
|
||||||
|
conversation: List[ConversationMessage],
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
request_metadata: RequestResponseMetadata,
|
||||||
|
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||||
|
|
||||||
|
model_name = self.base_model_paths[0].name
|
||||||
|
created_time = int(time.time())
|
||||||
|
final_res: Optional[RequestOutput] = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for res in result_generator:
|
||||||
|
final_res = res
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return self.create_error_response("Client disconnected")
|
||||||
|
|
||||||
|
assert final_res is not None
|
||||||
|
|
||||||
|
choices: List[ChatCompletionResponseChoice] = []
|
||||||
|
|
||||||
|
role = self.get_chat_request_role(request)
|
||||||
|
for output in final_res.outputs:
|
||||||
|
token_ids = output.token_ids
|
||||||
|
out_logprobs = output.logprobs
|
||||||
|
|
||||||
|
if request.logprobs and request.top_logprobs is not None:
|
||||||
|
assert out_logprobs is not None, "Did not output logprobs"
|
||||||
|
logprobs = self._create_chat_logprobs(
|
||||||
|
token_ids=token_ids,
|
||||||
|
top_logprobs=out_logprobs,
|
||||||
|
num_output_top_logprobs=request.top_logprobs,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
# In the OpenAI API the finish_reason is "tools_called"
|
||||||
|
# if the tool choice is auto and the model produced a tool
|
||||||
|
# call. The same is not true for named function calls
|
||||||
|
auto_tools_called = False
|
||||||
|
|
||||||
|
# Extract reasoning content if parser is configured.
|
||||||
|
# output_text is what remains after stripping <think>...</think>.
|
||||||
|
reasoning_text: Optional[str] = None
|
||||||
|
output_text: str = output.text
|
||||||
|
if self.reasoning_parser_cls:
|
||||||
|
r_parser = self.reasoning_parser_cls(
|
||||||
|
tokenizer,
|
||||||
|
chat_template_kwargs=request.chat_template_kwargs)
|
||||||
|
reasoning_text, extracted = r_parser.extract_reasoning(
|
||||||
|
output.text, request)
|
||||||
|
output_text = extracted or ""
|
||||||
|
|
||||||
|
# if auto tools are not enabled, and a named tool choice using
|
||||||
|
# outlines is not being used
|
||||||
|
if (not self.enable_auto_tools
|
||||||
|
or not self.tool_parser) and not isinstance(
|
||||||
|
request.tool_choice,
|
||||||
|
ChatCompletionNamedToolChoiceParam):
|
||||||
|
message = ChatMessage(role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content=output_text)
|
||||||
|
|
||||||
|
# if the request uses tools and specified a tool choice
|
||||||
|
elif request.tool_choice and type(
|
||||||
|
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||||
|
|
||||||
|
message = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name=request.tool_choice.function.name,
|
||||||
|
arguments=output_text))
|
||||||
|
])
|
||||||
|
|
||||||
|
# if the request doesn't use tool choice
|
||||||
|
# OR specifies to not use a tool
|
||||||
|
elif not request.tool_choice or request.tool_choice == "none":
|
||||||
|
|
||||||
|
message = ChatMessage(role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content=output_text)
|
||||||
|
|
||||||
|
# handle when there are tools and tool choice is auto
|
||||||
|
elif request.tools and (
|
||||||
|
request.tool_choice == "auto"
|
||||||
|
or request.tool_choice is None) and self.enable_auto_tools \
|
||||||
|
and self.tool_parser:
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_parser = self.tool_parser(tokenizer)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error("Error in tool parser creation: %s", e)
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
# Parse tool calls from the post-reasoning content.
|
||||||
|
tool_call_info = tool_parser.extract_tool_calls(
|
||||||
|
output_text, request=request)
|
||||||
|
auto_tools_called = tool_call_info.tools_called
|
||||||
|
if tool_call_info.tools_called:
|
||||||
|
message = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content=tool_call_info.content,
|
||||||
|
tool_calls=tool_call_info.tool_calls)
|
||||||
|
else:
|
||||||
|
message = ChatMessage(role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content=output_text)
|
||||||
|
|
||||||
|
# undetermined case that is still important to handle
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Error in chat_completion_full_generator - cannot determine"
|
||||||
|
" if tools should be extracted. Returning a standard chat "
|
||||||
|
"completion.")
|
||||||
|
message = ChatMessage(role=role,
|
||||||
|
reasoning_content=reasoning_text,
|
||||||
|
content=output_text)
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
index=output.index,
|
||||||
|
message=message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason="tool_calls" if auto_tools_called else
|
||||||
|
output.finish_reason if output.finish_reason else "stop",
|
||||||
|
stop_reason=output.stop_reason)
|
||||||
|
choices.append(choice_data)
|
||||||
|
|
||||||
|
if request.echo or request.continue_final_message:
|
||||||
|
last_msg_content = ""
|
||||||
|
if conversation and "content" in conversation[-1] and conversation[
|
||||||
|
-1].get("role") == role:
|
||||||
|
last_msg_content = conversation[-1]["content"] or ""
|
||||||
|
|
||||||
|
for choice in choices:
|
||||||
|
full_message = last_msg_content + (choice.message.content
|
||||||
|
or "")
|
||||||
|
choice.message.content = full_message
|
||||||
|
|
||||||
|
assert final_res.prompt_token_ids is not None
|
||||||
|
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||||
|
if final_res.encoder_prompt_token_ids is not None:
|
||||||
|
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
|
||||||
|
num_generated_tokens = sum(
|
||||||
|
len(output.token_ids) for output in final_res.outputs)
|
||||||
|
total_reasoning_tokens: Optional[int] = None
|
||||||
|
if self.reasoning_parser_cls:
|
||||||
|
rp = self.reasoning_parser_cls(
|
||||||
|
tokenizer,
|
||||||
|
chat_template_kwargs=request.chat_template_kwargs)
|
||||||
|
total_reasoning_tokens = sum(
|
||||||
|
rp.count_reasoning_tokens(list(output.token_ids))
|
||||||
|
for output in final_res.outputs)
|
||||||
|
usage = UsageInfo(
|
||||||
|
prompt_tokens=num_prompt_tokens,
|
||||||
|
completion_tokens=num_generated_tokens,
|
||||||
|
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||||
|
reasoning_tokens=total_reasoning_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_metadata.final_usage_info = usage
|
||||||
|
|
||||||
|
response = ChatCompletionResponse(
|
||||||
|
id=request_id,
|
||||||
|
created=created_time,
|
||||||
|
model=model_name,
|
||||||
|
choices=choices,
|
||||||
|
usage=usage,
|
||||||
|
prompt_logprobs=final_res.prompt_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _get_top_logprobs(
|
||||||
|
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
|
||||||
|
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
|
||||||
|
return [
|
||||||
|
ChatCompletionLogProb(token=(token := self._get_decoded_token(
|
||||||
|
p[1],
|
||||||
|
p[0],
|
||||||
|
tokenizer,
|
||||||
|
return_as_token_id=self.return_tokens_as_token_ids)),
|
||||||
|
logprob=max(p[1].logprob, -9999.0),
|
||||||
|
bytes=list(
|
||||||
|
token.encode("utf-8", errors="replace")))
|
||||||
|
for i, p in enumerate(logprobs.items())
|
||||||
|
if top_logprobs and i < top_logprobs
|
||||||
|
]
|
||||||
|
|
||||||
|
def _create_chat_logprobs(
|
||||||
|
self,
|
||||||
|
token_ids: GenericSequence[int],
|
||||||
|
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
num_output_top_logprobs: Optional[int] = None,
|
||||||
|
) -> ChatCompletionLogProbs:
|
||||||
|
"""Create OpenAI-style logprobs."""
|
||||||
|
logprobs_content: List[ChatCompletionLogProbsContent] = []
|
||||||
|
|
||||||
|
for i, token_id in enumerate(token_ids):
|
||||||
|
step_top_logprobs = top_logprobs[i]
|
||||||
|
if step_top_logprobs is None:
|
||||||
|
token = tokenizer.decode(token_id)
|
||||||
|
if self.return_tokens_as_token_ids:
|
||||||
|
token = f"token_id:{token_id}"
|
||||||
|
|
||||||
|
logprobs_content.append(
|
||||||
|
ChatCompletionLogProbsContent(
|
||||||
|
token=token,
|
||||||
|
bytes=list(token.encode("utf-8", errors="replace")),
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
step_token = step_top_logprobs[token_id]
|
||||||
|
step_decoded = step_token.decoded_token
|
||||||
|
|
||||||
|
logprobs_content.append(
|
||||||
|
ChatCompletionLogProbsContent(
|
||||||
|
token=self._get_decoded_token(
|
||||||
|
step_token,
|
||||||
|
token_id,
|
||||||
|
tokenizer,
|
||||||
|
self.return_tokens_as_token_ids,
|
||||||
|
),
|
||||||
|
logprob=max(step_token.logprob, -9999.0),
|
||||||
|
bytes=None if step_decoded is None else list(
|
||||||
|
step_decoded.encode("utf-8", errors="replace")),
|
||||||
|
top_logprobs=self._get_top_logprobs(
|
||||||
|
step_top_logprobs,
|
||||||
|
num_output_top_logprobs,
|
||||||
|
tokenizer,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
return ChatCompletionLogProbs(content=logprobs_content)
|
||||||
|
|
||||||
|
def _should_stream_with_auto_tool_parsing(self,
|
||||||
|
request: ChatCompletionRequest):
|
||||||
|
"""
|
||||||
|
Utility function to check if streamed tokens should go through the tool
|
||||||
|
call parser that was configured.
|
||||||
|
|
||||||
|
We only want to do this IF user-provided tools are set, a tool parser
|
||||||
|
is configured, "auto" tool choice is enabled, and the request's tool
|
||||||
|
choice field indicates that "auto" tool choice should be used.
|
||||||
|
"""
|
||||||
|
return (request.tools and self.tool_parser and self.enable_auto_tools
|
||||||
|
and request.tool_choice in ['auto', None])
|
||||||
|
|
||||||
|
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||||
|
self,
|
||||||
|
delta_message: Optional[DeltaMessage],
|
||||||
|
output: CompletionOutput,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check to see if we should check for unstreamed tool arguments tokens.
|
||||||
|
This is only applicable when auto tool parsing is enabled, the delta
|
||||||
|
is a tool call with arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
return bool(
|
||||||
|
# if there is a delta message that includes tool calls which
|
||||||
|
# include a function that has arguments
|
||||||
|
output.finish_reason is not None
|
||||||
|
and self.enable_auto_tools and self.tool_parser and delta_message
|
||||||
|
and delta_message.tool_calls and delta_message.tool_calls[0]
|
||||||
|
and delta_message.tool_calls[0].function
|
||||||
|
and delta_message.tool_calls[0].function.arguments is not None
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user