[Feat]Xlite Qwen3-vl Support (#5228)

### What this PR does / why we need it?
This patch adds support for the Qwen3-VL model in Xlite. For more
details about Xlite, see
https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md.
The latest performance comparison data between xlite and the default
aclgraph mode is as follows:

### Does this PR introduce _any_ user-facing change?
XLite graph mode supports the Qwen3-VL model.
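
A minimal sketch of how the option could be passed when serving a Qwen3-VL checkpoint. It assumes the `--additional-config` flag and the `xlite_graph_config` keys (`enabled`, `full_mode`) shown in the documentation changes of this commit. `path/to/Qwen3-VL` is a placeholder, not a real model location, and the script only prints the command rather than running it.

```python
import json
import shlex

# Keys come from the xlite_graph_config table in this commit's docs.
# full_mode=False keeps xlite on the decode stage only (the documented default).
additional_config = {"xlite_graph_config": {"enabled": True, "full_mode": False}}

# "path/to/Qwen3-VL" is a hypothetical placeholder for a local checkpoint.
command = [
    "vllm",
    "serve",
    "path/to/Qwen3-VL",
    "--additional-config",
    json.dumps(additional_config),
]

# shlex.join quotes the JSON payload so it can be pasted into a shell.
print(shlex.join(command))
```

Building the JSON with `json.dumps` avoids hand-written quoting mistakes such as Python-style `True` in place of JSON `true`.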

### How was this patch tested?
vLLM version: v0.12.0 

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

Signed-off-by: lvjunqi <lvjunqi1@huawei.com>
Co-authored-by: lvjunqi <lvjunqi1@huawei.com>
lvjunqi
2025-12-22 16:30:52 +08:00
committed by GitHub
parent 78aa7f2693
commit 55beac9c91
4 changed files with 19 additions and 9 deletions


@@ -49,7 +49,7 @@ The details of each configuration option are as follows:
 **xlite_graph_config**
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
-| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
+| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama, Qwen dense series models, and Qwen3-vl are supported. |
 | `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |
 **weight_prefetch_config**


@@ -12,7 +12,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa
 There are two kinds for graph mode supported by vLLM Ascend:
 - **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
-- **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense serise models are supported.
+- **XliteGraph**: This is the openeuler xlite graph mode. In v0.11.0, only Llama, Qwen dense series models, and Qwen3-vl are supported.
 ## Using ACLGraph
 ACLGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine is enough.
@@ -36,7 +36,7 @@ vllm serve Qwen/Qwen2-7B-Instruct
 ## Using XliteGraph
-If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite, and set xlite_graph_config.
+If you want to run Llama, Qwen dense series models, or Qwen3-vl with xlite graph mode, please install xlite, and set xlite_graph_config.
 ```bash
 pip install xlite
@@ -59,7 +59,7 @@ Online example:
 vllm serve path/to/Qwen3-32B --tensor-parallel-size 8 --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
 ```
-You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md)
+You can find more details abort xlite [here](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md)
 ## Fallback to the Eager Mode


@@ -281,7 +281,7 @@ class NPUPlatform(Platform):
             parallel_config.all2all_backend = "flashinfer_all2allv"
             if ascend_config.xlite_graph_config.enabled:
                 logger.info(
-                    "Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
+                    "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
                 )
                 parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
             else:


@@ -48,16 +48,23 @@ class LlamaXliteModel(XliteModel):
             vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
         dtype = vllm_config.model_config.dtype
         params_dict = dict(runnable.named_parameters())
-        layers = runnable.model.layers
+        if hasattr(runnable, "language_model"):
+            layers = runnable.language_model.model.layers
+            model_prefix = "language_model."
+        else:
+            layers = runnable.model.layers
+            model_prefix = ""
         config = self._build_model_config(vllm_config)
         xlite_model = Model()
-        xlite_model.embed = params_dict.get("model.embed_tokens.weight")
-        xlite_model.norm = params_dict.get("model.norm.weight")
+        xlite_model.embed = params_dict.get(model_prefix +
+                                            "model.embed_tokens.weight")
+        xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
         if vllm_config.model_config.hf_config.tie_word_embeddings:
             xlite_model.head = xlite_model.embed
         else:
-            xlite_model.head = params_dict.get("lm_head.weight")
+            xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
         xlite_model.attn_norm = [
             layer.input_layernorm.weight for layer in layers
         ]
@@ -112,6 +119,8 @@ class LlamaXliteModel(XliteModel):
     def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
         hf_config = vllm_config.model_config.hf_config
+        if hasattr(hf_config, "text_config"):
+            hf_config = hf_config.text_config
         config = ModelConfig()
         config.vocab_size = hf_config.vocab_size
         config.hidden_size = hf_config.hidden_size
@@ -166,6 +175,7 @@ def xlite_model_init(
         "LlamaForCausalLM": LlamaXliteModel,
         "Qwen2ForCausalLM": LlamaXliteModel,
         "Qwen3ForCausalLM": LlamaXliteModel,
+        "Qwen3VLForConditionalGeneration": LlamaXliteModel,
     }
     architecture = vllm_config.model_config.architectures[0]
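
The two `hasattr` checks in this file can be read as one idea: when the model is a vision-language wrapper, look one level down for the text decoder and its config. The sketch below illustrates that lookup in isolation. The helper names `resolve_text_backbone` and `resolve_text_config` and the `SimpleNamespace` stand-ins are hypothetical and are not part of the patch; only the attribute names (`language_model`, `model.layers`, `text_config`) and the parameter keys come from the diff.

```python
from types import SimpleNamespace


def resolve_text_backbone(runnable):
    """Return (decoder layers, parameter-name prefix) for a model object."""
    if hasattr(runnable, "language_model"):
        # Wrapped model: the decoder stack sits under `language_model`,
        # so its parameter names carry a "language_model." prefix.
        return runnable.language_model.model.layers, "language_model."
    # Plain text-only model: layers and parameters are at the top level.
    return runnable.model.layers, ""


def resolve_text_config(hf_config):
    """Use the nested `text_config` when present, else the config itself."""
    if hasattr(hf_config, "text_config"):
        return hf_config.text_config
    return hf_config


# Hypothetical stand-ins for a dense model and a vision-language wrapper.
dense = SimpleNamespace(model=SimpleNamespace(layers=["layer0", "layer1"]))
wrapped = SimpleNamespace(
    language_model=SimpleNamespace(model=SimpleNamespace(layers=["layer0"]))
)

dense_layers, dense_prefix = resolve_text_backbone(dense)
wrapped_layers, wrapped_prefix = resolve_text_backbone(wrapped)

# The prefix is what turns "model.norm.weight" into the key that
# named_parameters() would report for the wrapped model.
dense_key = dense_prefix + "model.norm.weight"
wrapped_key = wrapped_prefix + "model.norm.weight"
print(dense_key)
print(wrapped_key)
```

Because only the prefix and the config source differ, the same `LlamaXliteModel` class can be registered for `Qwen3VLForConditionalGeneration` without a separate model implementation.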