xc-llm-ascend/docs/source/tutorials/models/Qwen3_reranker.md
SILONG ZENG a1f321a556 [Doc]Refresh model tutorial examples and serving commands (#7426)
### What this PR does / why we need it?
Main updates include:
- update model IDs and default model paths in serving / offline
inference examples

- adjust some command snippets and notes for better copy-paste usability

- replace the `SamplingParams` argument `max_completion_tokens` with
`max_tokens` (**offline** inference currently **does not support**
`max_completion_tokens`):
``` bash
Traceback (most recent call last):
  File "/vllm-workspace/vllm-ascend/qwen-next.py", line 18, in <module>
    sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40, max_completion_tokens=32)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Unexpected keyword argument 'max_completion_tokens'
[ERROR] 2026-03-17-09:57:40 (PID:276, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exception
```
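For reference, the corrected call looks like this (a minimal sketch; the sampling values are copied from the failing snippet above):

``` python
from vllm import SamplingParams

# `max_tokens` is the supported argument for offline inference;
# `max_completion_tokens` raises the TypeError shown above.
sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=40, max_tokens=32)
```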

- refresh the recommended environment variables for
**Qwen3-Omni-30B-A3B-Thinking**:
``` bash
export HCCL_BUFFSIZE=512
export HCCL_OP_EXPANSION_MODE=AIV
```
``` bash
EZ9999[PID: 25038] 2026-03-17-08:21:12.001.372 (EZ9999):  HCCL_BUFFSIZE is too SMALL, maxBs = 256, h = 2048, 
epWorldSize = 2, localMoeExpertNum = 64, sharedExpertNum = 0, tokenNeedSizeDispatch = 4608, tokenNeedSizeCombine 
= 4096, k = 8, NEEDED_HCCL_BUFFSIZE(((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + 
(maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2) = 305MB, HCCL_BUFFSIZE=200MB.
[FUNC:CheckWinSize][FILE:moe_distribute_dispatch_v2_tiling.cpp][LINE:984]
```
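Plugging the values from the log into its own formula shows why the 200 MB default falls short (a quick arithmetic sketch; sizes are in bytes):

``` python
# Values reported in the EZ9999 error above
max_bs = 256
token_need_size_dispatch = 4608
token_need_size_combine = 4096
ep_worldsize = 2
local_moe_expert_num = 64
shared_expert_num = 0
k = 8

needed = (
    max_bs * token_need_size_dispatch * ep_worldsize * local_moe_expert_num
    + max_bs * token_need_size_combine * (k + shared_expert_num)
) * 2

# ~304 MB by this formula (the log reports 305 MB), well above HCCL_BUFFSIZE=200
print(f"{needed / 2**20:.0f} MB needed")
```

Raising `HCCL_BUFFSIZE` to 512 leaves headroom above the ~305 MB requirement.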

- fix **Qwen3-reranker** example usage to match the current **pooling
runner** interface and score output access
``` python
model = LLM(
    model=model_name,
    task="score",       # need fix
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
```
--->
``` python
model = LLM(
    model=model_name,
    runner="pooling",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
```

- modify the **PaddleOCR-VL** parameter `TASK_QUEUE_ENABLE` from `2` to `1`:
``` bash
(EngineCore_DP0 pid=26273) RuntimeError: NPUModelRunner init failed, error is NPUModelRunner failed, error
 is Do not support TASK_QUEUE_ENABLE = 2 during NPU graph capture, please export TASK_QUEUE_ENABLE=1/0.
```
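The error message itself names the fix:

``` bash
# NPU graph capture only supports TASK_QUEUE_ENABLE=1 or 0
export TASK_QUEUE_ENABLE=1
```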

These changes are needed because several documentation examples had
drifted from the current runtime behavior and recommended invocation
patterns, which could confuse users when following the tutorials
directly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4497431df6

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-03-20 11:34:18 +08:00


# Qwen3-Reranker

## Introduction

The Qwen3 Reranker model series is the latest proprietary model of the Qwen family, specifically designed for text embedding and ranking tasks. Building upon the dense foundational models of the Qwen3 series, it provides a comprehensive range of text embedding and reranking models in various sizes (0.6B, 4B, and 8B). This guide describes how to run the model with vLLM Ascend. Note that only vLLM Ascend v0.9.2rc1 and higher supports the model.

## Supported Features

Refer to supported features for the model's supported feature matrix.

## Environment Preparation

### Model Weight

It is recommended to download the model weights to a directory shared by all nodes, such as `/root/.cache/`.
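For example, with `huggingface-cli` (the target path is only a suggestion matching the recommendation above):

```bash
# Download Qwen3-Reranker-8B into the shared cache directory
huggingface-cli download Qwen/Qwen3-Reranker-8B --local-dir /root/.cache/Qwen3-Reranker-8B
```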

### Installation

You can use our official docker image to run Qwen3-Reranker series models.

- Start the docker image on your node; refer to using docker.

If you don't want to use the docker image above, you can also build everything from source, as sketched below.
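A rough outline of a source build (the repository URLs and the `VLLM_TARGET_DEVICE=empty` step reflect the usual vllm-ascend installation flow; treat this as a sketch, not the authoritative guide):

```bash
# Install vLLM from source as a dependency (device-agnostic build),
# then install vLLM Ascend from source
git clone https://github.com/vllm-project/vllm.git
cd vllm
VLLM_TARGET_DEVICE=empty pip install -v -e .
cd ..

git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
pip install -v -e .
cd ..
```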

## Deployment

Using the Qwen3-Reranker-8B model as an example, first start the docker container on your node.
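The exact launch command depends on your environment; the sketch below follows the pattern used in other vllm-ascend tutorials (the image tag, device IDs, and mounts are assumptions; adjust them for your setup):

```bash
# Pick a released vllm-ascend image tag for your vLLM Ascend version
export IMAGE=quay.io/ascend/vllm-ascend:<tag>

docker run --rm \
    --name vllm-ascend \
    --device /dev/davinci0 \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /root/.cache:/root/.cache \
    -p 8888:8888 \
    -it $IMAGE bash
```

The commands in the rest of this guide are run inside this container.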

### Online Inference

```bash
vllm serve Qwen/Qwen3-Reranker-8B --host 127.0.0.1 --port 8888 --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
```

Once your server is started, you can send requests with the following example.

**requests demo + formatting query & document**

```python
import requests

url = "http://127.0.0.1:8888/v1/rerank"

# Please use the query_template and document_template to format the query and
# document for better reranker results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"

instruction = (
    "Given a web search query, retrieve relevant passages that answer the query"
)

query = "What is the capital of China?"

documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

documents = [
    document_template.format(doc=doc, suffix=suffix) for doc in documents
]

response = requests.post(url,
                         json={
                             "query": query_template.format(prefix=prefix, instruction=instruction, query=query),
                             "documents": documents,
                         }).json()

print(response)
```

If you run this script successfully, you will see a list of scores printed to the console, similar to this:

```bash
{'id': 'rerank-e856a17c954047a3a40f73d5ec43dbc6', 'model': 'Qwen/Qwen3-Reranker-8B', 'usage': {'total_tokens': 193}, 'results': [{'index': 0, 'document': {'text': '<Document>: The capital of China is Beijing.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n', 'multi_modal': None}, 'relevance_score': 0.9944348335266113}, {'index': 1, 'document': {'text': '<Document>: Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n', 'multi_modal': None}, 'relevance_score': 6.700084327349032e-07}]}
```
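Continuing from the script above, the `results` list can be reduced to ranked `(index, score)` pairs (a small post-processing sketch based on the response shape shown):

```python
# Sort documents by relevance score, highest first
ranked = sorted(response["results"], key=lambda r: r["relevance_score"], reverse=True)
for r in ranked:
    print(r["index"], f"{r['relevance_score']:.4f}")
```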

### Offline Inference

```python
from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-8B"

# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking by using the
# logits of the "no" and "yes" tokens.
# It needs to compute the logits for all 151,669 vocabulary tokens, which
# makes this method extremely inefficient, not to mention incompatible with
# the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method are not only more efficient and
# compatible with the vllm score API, but also have more concise init
# parameters, for example:
# model = LLM(model="Qwen/Qwen3-Reranker-8B", task="score")

# If you want to load the official original version, the init parameters are
# as follows.

model = LLM(
    model=model_name,
    runner="pooling",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

# Why hf_overrides is needed for the official original version:
# vllm converts the model to Qwen3ForSequenceClassification when it is
# loaded, for better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes the model to Qwen3ForSequenceClassification.
# - Then, `"classifier_from_token": ["no", "yes"]` extracts the vectors
#   corresponding to the "no" and "yes" tokens from lm_head.
# - Third, these two vectors are converted into a single classifier vector;
#   this conversion logic is enabled by `"is_original_qwen3_reranker": True`.

# Please use the query_template and document_template to format the query and
# document for better reranker results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"

if __name__ == "__main__":
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    query = "What is the capital of China?"

    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]

    outputs = model.score(query_template.format(prefix=prefix, instruction=instruction, query=query), documents)

    print([output.outputs.score for output in outputs])
```

If you run this script successfully, you will see a list of scores printed to the console, similar to this:

```bash
[0.9943699240684509, 6.876250040477316e-07]
```
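To map scores back to their documents, the two lists can be zipped together (continuing from the offline script above; `model.score` returns one output per input document, in order):

```python
# Pair each templated document with its score, highest first
scored = sorted(zip(documents, outputs), key=lambda p: p[1].outputs.score, reverse=True)
for doc, output in scored:
    print(f"{output.outputs.score:.4f}  {doc[:50]}")
```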

## Performance

Run the performance benchmark with Qwen3-Reranker-8B as an example. Refer to vllm benchmark for more details.

Take serving as an example. Run the following command:

```bash
vllm bench serve --model Qwen3-Reranker-8B --backend vllm-rerank --dataset-name random-rerank --host 127.0.0.1 --port 8888 --endpoint /v1/rerank --tokenizer /root/.cache/Qwen3-Reranker-8B --random-input 200 --save-result --result-dir ./
```

After several minutes, you will get the performance evaluation result. With this tutorial, the performance result is:

```bash
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  6.78
Total input tokens:                      108032
Request throughput (req/s):              31.11
Total Token throughput (tok/s):          15929.35
----------------End-to-end Latency----------------
Mean E2EL (ms):                          4422.79
Median E2EL (ms):                        4412.58
P99 E2EL (ms):                           6294.52
==================================================
```