diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 8f91f43e..a1130e72 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -49,7 +49,7 @@ The details of each configuration option are as follows:
 **xlite_graph_config**
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
-| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
+| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama, Qwen dense series models, and Qwen3-VL are supported. |
 | `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |
 
 **weight_prefetch_config**
diff --git a/docs/source/user_guide/feature_guide/graph_mode.md b/docs/source/user_guide/feature_guide/graph_mode.md
index 76868949..a74d59df 100644
--- a/docs/source/user_guide/feature_guide/graph_mode.md
+++ b/docs/source/user_guide/feature_guide/graph_mode.md
@@ -12,7 +12,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa
 There are two kinds for graph mode supported by vLLM Ascend:
 
 - **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
-- **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense serise models are supported.
+- **XliteGraph**: This is the openEuler xlite graph mode. In v0.11.0, only Llama, Qwen dense series models, and Qwen3-VL are supported.
 
 ## Using ACLGraph
 ACLGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine is enough.
@@ -36,7 +36,7 @@ vllm serve Qwen/Qwen2-7B-Instruct
 
 ## Using XliteGraph
 
-If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite, and set xlite_graph_config.
+If you want to run Llama, Qwen dense series models, or Qwen3-VL with xlite graph mode, please install xlite and set xlite_graph_config.
 
 ```bash
 pip install xlite
@@ -59,7 +59,7 @@ Online example:
 vllm serve path/to/Qwen3-32B --tensor-parallel-size 8 --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
 ```
 
-You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md)
+You can find more details about xlite [here](https://atomgit.com/openeuler/GVirt/blob/master/xlite/README.md).
 
 ## Fallback to the Eager Mode
 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 2a70932d..7f2eb793 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -281,7 +281,7 @@ class NPUPlatform(Platform):
             parallel_config.all2all_backend = "flashinfer_all2allv"
         if ascend_config.xlite_graph_config.enabled:
             logger.info(
-                "Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
+                "openEuler Xlite enabled. See: https://atomgit.com/openeuler/GVirt/tree/master/xlite"
             )
             parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
         else:
diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py
index c41734a4..462052d7 100644
--- a/vllm_ascend/xlite/xlite.py
+++ b/vllm_ascend/xlite/xlite.py
@@ -48,16 +48,25 @@ class LlamaXliteModel(XliteModel):
                 vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
         dtype = vllm_config.model_config.dtype
         params_dict = dict(runnable.named_parameters())
-        layers = runnable.model.layers
+
+        # Multimodal wrappers such as Qwen3-VL expose the decoder through a
+        # `language_model` submodule, which also prefixes the weight names.
+        if hasattr(runnable, "language_model"):
+            layers = runnable.language_model.model.layers
+            model_prefix = "language_model."
+        else:
+            layers = runnable.model.layers
+            model_prefix = ""
 
         config = self._build_model_config(vllm_config)
         xlite_model = Model()
-        xlite_model.embed = params_dict.get("model.embed_tokens.weight")
-        xlite_model.norm = params_dict.get("model.norm.weight")
+        xlite_model.embed = params_dict.get(model_prefix +
+                                            "model.embed_tokens.weight")
+        xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
         if vllm_config.model_config.hf_config.tie_word_embeddings:
             xlite_model.head = xlite_model.embed
         else:
-            xlite_model.head = params_dict.get("lm_head.weight")
+            xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
         xlite_model.attn_norm = [
             layer.input_layernorm.weight for layer in layers
         ]
@@ -112,6 +121,9 @@ class LlamaXliteModel(XliteModel):
 
     def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
         hf_config = vllm_config.model_config.hf_config
+        # Multimodal configs nest the decoder settings under `text_config`.
+        if hasattr(hf_config, "text_config"):
+            hf_config = hf_config.text_config
         config = ModelConfig()
         config.vocab_size = hf_config.vocab_size
         config.hidden_size = hf_config.hidden_size
@@ -166,6 +178,7 @@ def xlite_model_init(
         "LlamaForCausalLM": LlamaXliteModel,
         "Qwen2ForCausalLM": LlamaXliteModel,
         "Qwen3ForCausalLM": LlamaXliteModel,
+        "Qwen3VLForConditionalGeneration": LlamaXliteModel,
    }

    architecture = vllm_config.model_config.architectures[0]
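
With this change, Qwen3-VL is routed through the same LlamaXliteModel conversion path as the dense Qwen models, so it is served the same way as the existing online example in graph_mode.md. A minimal sketch, assuming a local Qwen3-VL checkpoint (the model path and tensor-parallel size below are placeholders, not values from this patch):

```bash
# Install the openEuler xlite package, then serve with xlite graph mode
# enabled for both the prefill and decode stages ("full_mode": true).
pip install xlite
vllm serve path/to/Qwen3-VL-32B --tensor-parallel-size 8 \
    --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
```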