[Feat] Add Euler xlite graph wrapper support (#4526)
### What this PR does / why we need it?

This patch adds support for the xlite graph wrapper to vllm_ascend. Xlite provides operator implementations of the transformer network on Ascend hardware. For details about xlite, please refer to the following link: https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md

The latest performance comparison data between xlite and the default aclgraph mode is as follows:

## Qwen3 32B TPS 910B3(A2) Online Inference Performance Comparison

- aclgraph: main(c4a71fc6)
- xlite-full: main(c4a71fc6) + xlite-full
- xlite-decode-only: main(c4a71fc6) + xlite-decode-only
- diff1: performance comparison between xlite-full and aclgraph
- diff2: performance comparison between xlite-decode-only and aclgraph

### Does this PR introduce _any_ user-facing change?

Enable the xlite graph mode by setting `xlite_graph_config`:

```bash
--additional-config='{"xlite_graph_config": {"enabled": true}}'                     # enabled for decode only
--additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'  # enabled for prefill and decode
```

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: lulina <lina.lulina@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
.github/workflows/_e2e_test.yaml
@@ -103,6 +103,7 @@ jobs:
          pytest -sv tests/e2e/singlecard/test_sampler.py
          pytest -sv tests/e2e/singlecard/test_vlm.py
          pytest -sv tests/e2e/singlecard/multi-modal/test_internvl.py
          pytest -sv tests/e2e/singlecard/test_xlite.py

          # ------------------------------------ v1 spec decode test ------------------------------------ #
          pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
@@ -26,6 +26,7 @@ The following table lists additional configuration options available in vLLM Asc

| Name | Type | Default | Description |
|------|------|---------|-------------|
| `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
| `torchair_graph_config` | dict | `{}` | Configuration options for torchair graph mode |
| `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
@@ -45,6 +46,12 @@ The following table lists additional configuration options available in vLLM Asc

The details of each configuration option are as follows:

**xlite_graph_config**

| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
| `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |

**torchair_graph_config**

| Name | Type | Default | Description |
@@ -10,9 +10,10 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P

From v0.9.1rc1 with the V1 Engine, vLLM Ascend runs models in graph mode by default to keep the same behavior as vLLM. If you hit any issues, please feel free to open an issue on GitHub and fall back to eager mode temporarily by setting `enforce_eager=True` when initializing the model.

There are three kinds of graph mode supported by vLLM Ascend:

- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and DeepSeek series models are well tested.
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
- **XliteGraph**: This is the Euler xlite graph mode. In v0.11.0, only Llama and Qwen dense series models are supported.

## Using ACLGraph

ACLGraph is enabled by default. Taking Qwen series models as an example, setting the V1 Engine is enough.
@@ -57,9 +58,36 @@ vllm serve path/to/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config

You can find more details about additional configuration [here](../configuration/additional_config.md).

## Using XliteGraph

If you want to run Llama or Qwen dense series models with xlite graph mode, install xlite and set `xlite_graph_config`:

```bash
pip install xlite
```
||||
Offline example:
|
||||
|
||||
```python
|
||||
import os
|
||||
from vllm import LLM
|
||||
|
||||
# xlite supports the decode-only mode by default, and the full mode can be enabled by setting: "full_mode": True
|
||||
model = LLM(model="path/to/Qwen3-32B", tensor_parallel_size=8, additional_config={"xlite_graph_config": {"enabled": True, "full_mode": True}})
|
||||
outputs = model.generate("Hello, how are you?")
|
||||
```
|
||||
|
||||
Online example:
|
||||
|
||||
```shell
|
||||
vllm serve path/to/Qwen3-32B --tensor-parallel-size 8 --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
|
||||
```
|
||||
|
||||
You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md)
|
||||
|
||||
## Fallback to the Eager Mode
|
||||
|
||||
If both `ACLGraph` and `TorchAirGraph` fail to run, you should fallback to the eager mode.
|
||||
If `ACLGraph`, `TorchAirGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode.
|
||||
|
||||
Offline example:
|
||||
|
||||
|
||||
mypy.ini
@@ -27,3 +27,6 @@ ignore_missing_imports = True
[mypy-msprobe.*]
ignore_missing_imports = True
allow_untyped_imports = True

[mypy-xlite.*]
ignore_missing_imports = True
@@ -20,4 +20,5 @@ soundfile
pytest_mock
msserviceprofiler>=1.2.2
mindstudio-probe>=8.3.0
arctic-inference==0.1.1
xlite
tests/e2e/singlecard/test_xlite.py
@@ -0,0 +1,130 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without xlite.

Run `pytest tests/e2e/singlecard/test_xlite.py`.
"""

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODELS = [
    "Qwen/Qwen3-0.6B",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_xlite_decode_only(
    model: str,
    max_tokens: int,
) -> None:
    prompts = [
        "Hello, my name is", "The president of the United States is",
        "The capital of France is", "The future of AI is"
    ]

    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
    with VllmRunner(
            model,
            block_size=128,
            max_model_len=1024,
            enforce_eager=False,
            additional_config={"xlite_graph_config": {
                "enabled": True
            }},
    ) as runner:
        vllm_xlite_outputs = runner.model.generate(prompts, sampling_params)

    with VllmRunner(
            model,
            block_size=128,
            max_model_len=1024,
            enforce_eager=True,
    ) as runner:
        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

    vllm_xlite_outputs_list = []
    for output in vllm_xlite_outputs:
        vllm_xlite_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    vllm_eager_outputs_list = []
    for output in vllm_eager_outputs:
        vllm_eager_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=vllm_xlite_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="vllm_xlite_outputs",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_xlite_full_mode(
    model: str,
    max_tokens: int,
) -> None:
    prompts = [
        "Hello, my name is", "The president of the United States is",
        "The capital of France is", "The future of AI is"
    ]

    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
    with VllmRunner(
            model,
            block_size=128,
            max_model_len=1024,
            enforce_eager=False,
            additional_config={
                "xlite_graph_config": {
                    "enabled": True,
                    "full_mode": True
                }
            },
    ) as runner:
        vllm_xlite_outputs = runner.model.generate(prompts, sampling_params)

    with VllmRunner(
            model,
            block_size=128,
            max_model_len=1024,
            enforce_eager=True,
    ) as runner:
        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

    vllm_xlite_outputs_list = []
    for output in vllm_xlite_outputs:
        vllm_xlite_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    vllm_eager_outputs_list = []
    for output in vllm_eager_outputs:
        vllm_eager_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=vllm_xlite_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="vllm_xlite_outputs",
    )
@@ -32,6 +32,7 @@ class TestNPUPlatform(TestBase):
    def mock_vllm_ascend_config():
        mock_ascend_config = MagicMock()
        mock_ascend_config.torchair_graph_config.enabled = False
        mock_ascend_config.xlite_graph_config.enabled = False
        mock_ascend_config.enable_shared_expert_dp = False
        return mock_ascend_config

@@ -512,6 +513,16 @@ class TestNPUPlatform(TestBase):
            "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
        )

        test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
        test_ascend_config.xlite_graph_config.enabled = True
        mock_init_ascend.return_value = test_ascend_config
        vllm_config.parallel_config.worker_cls = "auto"
        self.platform.check_and_update_config(vllm_config)
        self.assertEqual(
            vllm_config.parallel_config.worker_cls,
            "vllm_ascend.xlite.xlite_worker.XliteWorker",
        )

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch('vllm_ascend.utils.get_ascend_device_type',
@@ -72,6 +72,10 @@ class AscendConfig:
        self.torchair_graph_config = TorchairGraphConfig(
            torchair_graph_config, vllm_config, additional_config)

        xlite_graph_config = additional_config.get("xlite_graph_config", {})
        self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
                                                   vllm_config)

        ascend_compilation_config = additional_config.get(
            "ascend_compilation_config", {})
        self.ascend_compilation_config = AscendCompilationConfig(
@@ -291,6 +295,29 @@ class TorchairGraphConfig:
        )


class XliteGraphConfig:
    """
    Configuration Object for xlite_graph_config from additional_config
    """

    def __init__(self, xlite_graph_config, vllm_config):
        self.enabled = xlite_graph_config.get("enabled", False)
        self.full_mode = xlite_graph_config.get("full_mode", False)
        if self.enabled:
            if bool(vllm_config.speculative_config):
                raise RuntimeError(
                    "Xlite graph mode is not compatible with speculative decoding. Please disable speculative decoding."
                )
            if vllm_config.parallel_config.pipeline_parallel_size > 1:
                raise RuntimeError(
                    "Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
                )
            if vllm_config.cache_config.block_size != 128:
                raise RuntimeError(
                    "Xlite graph mode is only compatible with block_size of 128. Please set block_size to 128."
                )
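The constraints above (no speculative decoding, no pipeline parallelism, block_size fixed at 128) can be exercised in isolation. Below is a minimal sketch that mirrors the same validation logic with `SimpleNamespace` stand-ins for the vLLM config objects; `MiniXliteGraphConfig` is an illustrative name, not the real class:

```python
from types import SimpleNamespace

class MiniXliteGraphConfig:
    """Standalone mirror of the XliteGraphConfig checks above (illustrative only)."""

    def __init__(self, cfg: dict, vllm_config) -> None:
        self.enabled = cfg.get("enabled", False)
        self.full_mode = cfg.get("full_mode", False)
        if not self.enabled:
            return  # constraints only apply when xlite is enabled
        if bool(vllm_config.speculative_config):
            raise RuntimeError("Xlite graph mode is not compatible with speculative decoding.")
        if vllm_config.parallel_config.pipeline_parallel_size > 1:
            raise RuntimeError("Xlite graph mode is not compatible with pipeline parallelism.")
        if vllm_config.cache_config.block_size != 128:
            raise RuntimeError("Xlite graph mode requires block_size=128.")

# A config that satisfies every constraint:
ok = SimpleNamespace(
    speculative_config=None,
    parallel_config=SimpleNamespace(pipeline_parallel_size=1),
    cache_config=SimpleNamespace(block_size=128),
)
MiniXliteGraphConfig({"enabled": True}, ok)  # passes

# A block size other than 128 is rejected:
bad = SimpleNamespace(
    speculative_config=None,
    parallel_config=SimpleNamespace(pipeline_parallel_size=1),
    cache_config=SimpleNamespace(block_size=16),
)
try:
    MiniXliteGraphConfig({"enabled": True}, bad)
except RuntimeError as e:
    print("rejected:", e)
```

Note that when `enabled` is false the checks are skipped entirely, so an otherwise incompatible deployment still constructs cleanly.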


class DumpConfig:
    """
    Configuration object for dump/PrecisionDebugger settings.
@@ -305,6 +305,11 @@ class NPUPlatform(Platform):
            parallel_config.all2all_backend = "flashinfer_all2allv"
        if ascend_config.torchair_graph_config.enabled:
            parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
        elif ascend_config.xlite_graph_config.enabled:
            logger.info(
                "Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
            )
            parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
        else:
            parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
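The hunk above picks the worker class with a simple priority chain: torchair first, then xlite, then the default NPU worker. The same selection in miniature (a standalone sketch, not the platform code itself):

```python
def select_worker_cls(torchair_enabled: bool, xlite_enabled: bool) -> str:
    """Mirror of the worker-class priority chain: torchair > xlite > default."""
    if torchair_enabled:
        return "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
    if xlite_enabled:
        return "vllm_ascend.xlite.xlite_worker.XliteWorker"
    return "vllm_ascend.worker.worker_v1.NPUWorker"

print(select_worker_cls(False, True))   # vllm_ascend.xlite.xlite_worker.XliteWorker
print(select_worker_cls(False, False))  # vllm_ascend.worker.worker_v1.NPUWorker
```

Because torchair is checked first, enabling both graph modes would silently resolve to the torchair worker.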
vllm_ascend/xlite/__init__.py

vllm_ascend/xlite/xlite.py
@@ -0,0 +1,275 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                AscendMetadata)
from vllm_ascend.utils import is_enable_nz

class XliteModel:

    def initialize(
            self, runnable: nn.Module,
            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
        raise NotImplementedError(
            "Xlite Model initialize function not implemented.")


class LlamaXliteModel(XliteModel):

    def initialize(
            self, runnable: nn.Module,
            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
        dtype = vllm_config.model_config.dtype
        params_dict = dict(runnable.named_parameters())
        layers = runnable.model.layers

        config = self._build_model_config(vllm_config)
        xlite_model = Model()
        xlite_model.embed = params_dict.get("model.embed_tokens.weight")
        xlite_model.norm = params_dict.get("model.norm.weight")
        if vllm_config.model_config.hf_config.tie_word_embeddings:
            xlite_model.head = xlite_model.embed
        else:
            xlite_model.head = params_dict.get("lm_head.weight")
        xlite_model.attn_norm = [
            layer.input_layernorm.weight for layer in layers
        ]
        xlite_model.attn_out = [
            layer.self_attn.o_proj.weight for layer in layers
        ]
        xlite_model.mha_qkv = [
            layer.self_attn.qkv_proj.weight for layer in layers
        ]
        xlite_model.mlp_norm = [
            layer.post_attention_layernorm.weight for layer in layers
        ]
        xlite_model.mlp_up_gate = [
            layer.mlp.gate_up_proj.weight for layer in layers
        ]
        xlite_model.mlp_down = [layer.mlp.down_proj.weight for layer in layers]
        mha_qkv_bias = [
            layer.self_attn.qkv_proj.bias for layer in layers
            if hasattr(layer.self_attn.qkv_proj, "bias")
            and layer.self_attn.qkv_proj.bias is not None
        ]
        q_norm = [
            layer.self_attn.q_norm.weight for layer in layers
            if hasattr(layer.self_attn, "q_norm")
        ]
        k_norm = [
            layer.self_attn.k_norm.weight for layer in layers
            if hasattr(layer.self_attn, "k_norm")
        ]

        if len(mha_qkv_bias) != config.n_layers:
            config.qkv_bias = False
        else:
            config.qkv_bias = True
            xlite_model.mha_qkv_bias = mha_qkv_bias

        if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
            config.qk_norm = False
        else:
            config.qk_norm = True
            xlite_model.mha_q_norm = q_norm
            xlite_model.mha_k_norm = k_norm

        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)

        freq_cis = self._precompute_freqs_cis(config.head_dim,
                                              config.max_seq_len, dtype,
                                              config.rope_theta)

        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        hf_config = vllm_config.model_config.hf_config
        config = ModelConfig()
        config.vocab_size = hf_config.vocab_size
        config.hidden_size = hf_config.hidden_size
        config.n_layers = hf_config.num_hidden_layers
        config.n_heads = hf_config.num_attention_heads
        config.n_kv_heads = hf_config.num_key_value_heads
        if hasattr(hf_config, "head_dim"):
            config.head_dim = hf_config.head_dim
        else:
            config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        config.rope_head_dim = config.head_dim
        config.norm_eps = hf_config.rms_norm_eps
        config.rope_theta = hf_config.rope_theta
        config.softmax_scale = config.head_dim**-0.5
        config.n_dense_layers = hf_config.num_hidden_layers
        config.intermediate_size = hf_config.intermediate_size
        config.def_tp_size = get_tensor_model_parallel_world_size()
        config.def_dp_size = 1
        config.moe_ep_size = 1
        config.moe_tp_size = 1

        config.attn_type = AttnMHA
        config.weight_nz = is_enable_nz()
        scheduler_config = vllm_config.scheduler_config
        max_batch_size = scheduler_config.max_num_seqs
        max_seq_len = vllm_config.model_config.max_model_len
        config.max_m = scheduler_config.max_num_batched_tokens
        config.max_batch_size = max_batch_size
        config.max_seq_len = max_seq_len
        config.block_size = vllm_config.cache_config.block_size
        return config

    def _precompute_freqs_cis(self,
                              dim: int,
                              end: int,
                              dtype: torch.dtype,
                              theta: float = 10000.0):
        freqs = 1.0 / (theta**(torch.arange(
            0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        cos_cache = freqs.cos().to(dtype)
        sin_cache = freqs.sin().to(dtype)
        freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
        return freq_cis.to(device='npu')
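`_precompute_freqs_cis` builds a rotary-embedding table of shape `[end, dim]`: cos and sin caches of shape `[end, dim // 2]` concatenated along the last axis. A dependency-free sketch of the same arithmetic in pure Python (no torch, illustrative only, function name is hypothetical):

```python
import math

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Return `end` rows, each the cos half followed by the sin half (length `dim`)."""
    # inverse frequencies: theta^(-i/dim) for i in 0, 2, 4, ..., dim-2
    freqs = [theta ** (-(i / dim)) for i in range(0, dim, 2)]
    table = []
    for t in range(end):  # outer product of positions and frequencies
        angles = [t * f for f in freqs]
        table.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return table

table = precompute_freqs_cis(dim=8, end=4)
print(len(table), len(table[0]))  # 4 8
```

The real method does the same with `torch.outer` and then moves the table to the NPU once, so it is computed a single time per model load.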


def xlite_model_init(
        runnable: nn.Module,
        vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
    strategy_map = {
        "LlamaForCausalLM": LlamaXliteModel,
        "Qwen2ForCausalLM": LlamaXliteModel,
        "Qwen3ForCausalLM": LlamaXliteModel,
    }

    architecture = vllm_config.model_config.architectures[0]
    strategy_class = strategy_map.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize(runnable, vllm_config)
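`xlite_model_init` is a plain strategy dispatch keyed on the model architecture string; the Qwen dense architectures reuse the Llama weight layout, so they share one strategy class. The same pattern in miniature (hypothetical strategy class and return value, illustrative only):

```python
class LlamaStrategy:
    def initialize(self) -> str:
        return "llama-init"

STRATEGY_MAP = {
    "LlamaForCausalLM": LlamaStrategy,
    "Qwen2ForCausalLM": LlamaStrategy,  # Qwen dense reuses the Llama layout
    "Qwen3ForCausalLM": LlamaStrategy,
}

def dispatch(architecture: str) -> str:
    strategy_class = STRATEGY_MAP.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize()

print(dispatch("Qwen3ForCausalLM"))  # llama-init
```

Adding a new architecture only means registering another entry in the map, which keeps the model-specific weight extraction out of the wrapper itself.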


class XliteWrapper:
    """
    xlite graph wrapper
    """

    def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
        self.runnable = runnable
        self.full_mode = get_ascend_config().xlite_graph_config.full_mode

        rank = torch.distributed.get_rank()
        local_rank = get_world_group().local_rank
        self.xlite_rt = Runtime(local_rank, 0, rank,
                                get_tensor_model_parallel_world_size())

        (self.xlite_model, self.freq_cis, hidden_size,
         dtype) = xlite_model_init(runnable, vllm_config)

        rt_pool_size = self.xlite_model.get_tensor_pool_size()
        if rank == 0:
            logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
        if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
            raise ValueError(
                f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
            )

        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_states = torch.empty(max_num_tokens,
                                         hidden_size,
                                         device=f"npu:{local_rank}",
                                         dtype=dtype)

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"xlite wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def register_kv_caches(self, kv_caches: Any):
        self.kv_caches = kv_caches

    def __call__(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
                                                    list[torch.Tensor]]:
        forward_context = get_forward_context()
        attn_metadata: Any = forward_context.attn_metadata
        if attn_metadata is None:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        attn_metadata = next(iter(attn_metadata.values()), None)
        if attn_metadata is None or not isinstance(attn_metadata,
                                                   AscendMetadata):
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]

        if not with_prefill or self.full_mode:
            batch = attn_metadata.num_prefills + attn_metadata.num_decodes
            seq_lens = attn_metadata.seq_lens[:batch]
            query_lens = attn_metadata.query_lens[:batch]
            cached_lens = seq_lens - query_lens

            xlite_attn_metadata = ModelAttnMeta()
            xlite_attn_metadata.lens = query_lens.tolist()
            xlite_attn_metadata.cached_lens = cached_lens.tolist()
            xlite_attn_metadata.is_prefills = [
                False
            ] * attn_metadata.num_decodes + [True] * attn_metadata.num_prefills
            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
            ).tolist()

            h = self.hidden_states[:attn_metadata.num_actual_tokens]
            stream = torch.npu.current_stream().npu_stream
            if inputs_embeds is None:
                self.xlite_model.forward(self.xlite_rt, input_ids,
                                         xlite_attn_metadata, self.kv_caches,
                                         self.freq_cis, h, stream)
            else:
                self.xlite_model.forward_with_inputs_embeds(
                    self.xlite_rt, inputs_embeds, xlite_attn_metadata,
                    self.kv_caches, self.freq_cis, h, stream)
            return h
        else:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)
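The `__call__` above routes a batch to the xlite graph only when it is pure decode (`DecodeOnly` or `SpecDecoding`), or unconditionally when full mode is on; anything else falls back to the wrapped module. The decision itself is a small pure function, sketched here with the attention-state names mirrored as plain strings (illustrative only):

```python
DECODE_STATES = {"DecodeOnly", "SpecDecoding"}

def use_xlite(attn_state: str, full_mode: bool) -> bool:
    """Mirror of the wrapper's routing: xlite handles decode-only batches,
    plus prefill batches when full_mode is enabled."""
    with_prefill = attn_state not in DECODE_STATES
    return (not with_prefill) or full_mode

print(use_xlite("DecodeOnly", full_mode=False))      # True
print(use_xlite("PrefillNoCache", full_mode=False))  # False
print(use_xlite("PrefillNoCache", full_mode=True))   # True
```

This is why the default configuration is effectively "decode only": prefill batches take the fallback path unless `full_mode` is set.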
vllm_ascend/xlite/xlite_model_runner.py
@@ -0,0 +1,36 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file
import torch.nn as nn
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class XliteModelRunner(NPUModelRunner):

    def get_model(self) -> nn.Module:
        return self.model.unwrap()

    def load_model(self) -> None:
        super().load_model()
        from vllm_ascend.xlite.xlite import XliteWrapper
        self.model = XliteWrapper(self.model, self.vllm_config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        super().initialize_kv_cache(kv_cache_config)
        self.model.register_kv_caches(self.kv_caches)
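`load_model` swaps `self.model` for the `XliteWrapper` while the rest of the runner keeps reading model attributes as before; that works because of the wrapper's `__getattr__`/`unwrap` pair defined in xlite.py. A minimal standalone sketch of this transparent-wrapper pattern (hypothetical class names, illustrative only):

```python
class Wrapper:
    """Transparent wrapper: unknown attribute reads fall through to the inner object."""

    def __init__(self, runnable):
        self.runnable = runnable

    def __getattr__(self, key: str):
        # __getattr__ only fires when normal lookup on the wrapper fails;
        # reading via __dict__ avoids recursing back into __getattr__.
        inner = self.__dict__["runnable"]
        if hasattr(inner, key):
            return getattr(inner, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable")

    def unwrap(self):
        # escape hatch for callers that need the original object
        return self.runnable

class Inner:
    config = "inner-config"

w = Wrapper(Inner())
print(w.config)           # delegated to Inner: inner-config
print(w.unwrap().config)  # inner-config
```

The wrapper's own attributes (`runnable`, and in the real class `full_mode`, `xlite_model`, etc.) shadow the inner module's, and everything else is delegated, so existing runner code is untouched.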
vllm_ascend/xlite/xlite_worker.py
@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.worker.worker_v1 import NPUWorker
from vllm_ascend.xlite.xlite_model_runner import XliteModelRunner


class XliteWorker(NPUWorker):
    """Xlite worker based on NPUWorker. Only xlite-specific code should be added in this class."""

    def init_device(self):
        """Override init_device to init the xlite model runner."""
        self.device = self._init_device()
        self.model_runner = XliteModelRunner(self.vllm_config, self.device)