[Feat] Add Euler xlite graph wrapper support (#4526)

### What this PR does / why we need it?
This patch adds support for the xlite graph wrapper to vllm_ascend.
Xlite provides operator implementations of the transformer network on
Ascend hardware. For details about xlite, please refer to the following
link: https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md
The latest performance comparison data between xlite and the default
aclgraph mode is as follows:

## Qwen3 32B TPS 910B3(A2) Online Inference Performance Comparison
- aclgraph: main(c4a71fc6) 
- xlite-full: main(c4a71fc6) + xlite-full
- xlite-decode-only: main(c4a71fc6) + xlite-decode-only
- diff1: Performance comparison between xlite-full and aclgraph
- diff2: Performance comparison between xlite-decode-only and aclgraph


### Does this PR introduce _any_ user-facing change?
Enable the xlite graph mode by setting xlite_graph_config:
--additional-config='{"xlite_graph_config": {"enabled": true}}' #
Enabled for decode only
--additional-config='{"xlite_graph_config": {"enabled": true,
"full_mode": true}}' # Enabled for prefill and decode

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: lulina <lina.lulina@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
LuLina
2025-12-08 08:27:46 +08:00
committed by GitHub
parent 8fdb689a32
commit 2be0fe2691
13 changed files with 553 additions and 3 deletions

View File

@@ -103,6 +103,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_sampler.py
pytest -sv tests/e2e/singlecard/test_vlm.py
pytest -sv tests/e2e/singlecard/multi-modal/test_internvl.py
pytest -sv tests/e2e/singlecard/test_xlite.py
# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

View File

@@ -26,6 +26,7 @@ The following table lists additional configuration options available in vLLM Asc
| Name | Type | Default | Description |
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode |
| `torchair_graph_config` | dict | `{}` | Configuration options for torchair graph mode |
| `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
@@ -45,6 +46,12 @@ The following table lists additional configuration options available in vLLM Asc
The details of each configuration option are as follows:
**xlite_graph_config**
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. |
| `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. |
**torchair_graph_config**
| Name | Type | Default | Description |

View File

@@ -10,9 +10,10 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P
From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by default to keep the same behavior with vLLM. If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by setting `enforce_eager=True` when initializing the model.
There are two kinds for graph mode supported by vLLM Ascend:
There are three kinds for graph mode supported by vLLM Ascend:
- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested.
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
- **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense serise models are supported.
## Using ACLGraph
ACLGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine is enough.
@@ -57,9 +58,36 @@ vllm serve path/to/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config
You can find more details about additional configuration [here](../configuration/additional_config.md).
## Using XliteGraph
If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite, and set xlite_graph_config.
```bash
pip install xlite
```
Offline example:
```python
import os
from vllm import LLM
# xlite supports the decode-only mode by default, and the full mode can be enabled by setting: "full_mode": True
model = LLM(model="path/to/Qwen3-32B", tensor_parallel_size=8, additional_config={"xlite_graph_config": {"enabled": True, "full_mode": True}})
outputs = model.generate("Hello, how are you?")
```
Online example:
```shell
vllm serve path/to/Qwen3-32B --tensor-parallel-size 8 --additional-config='{"xlite_graph_config": {"enabled": true, "full_mode": true}}'
```
You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/blob/master/xlite/README.md)
## Fallback to the Eager Mode
If both `ACLGraph` and `TorchAirGraph` fail to run, you should fallback to the eager mode.
If `ACLGraph`, `TorchAirGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode.
Offline example:

View File

@@ -27,3 +27,6 @@ ignore_missing_imports = True
[mypy-msprobe.*]
ignore_missing_imports = True
allow_untyped_imports = True
[mypy-xlite.*]
ignore_missing_imports = True

View File

@@ -20,4 +20,5 @@ soundfile
pytest_mock
msserviceprofiler>=1.2.2
mindstudio-probe>=8.3.0
arctic-inference==0.1.1
arctic-inference==0.1.1
xlite

View File

@@ -0,0 +1,130 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without xlite.
Run `pytest tests/e2e/singlecard/test_xlite.py`.
"""
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen3-0.6B",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_xlite_decode_only(
model: str,
max_tokens: int,
) -> None:
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
with VllmRunner(
model,
block_size=128,
max_model_len=1024,
enforce_eager=False,
additional_config={"xlite_graph_config": {
"enabled": True
}},
) as runner:
vllm_xlite_outputs = runner.model.generate(prompts, sampling_params)
with VllmRunner(
model,
block_size=128,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_xlite_outputs_list = []
for output in vllm_xlite_outputs:
vllm_xlite_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_xlite_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_xlite_outputs",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_xlite_full_mode(
model: str,
max_tokens: int,
) -> None:
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
with VllmRunner(
model,
block_size=128,
max_model_len=1024,
enforce_eager=False,
additional_config={
"xlite_graph_config": {
"enabled": True,
"full_mode": True
}
},
) as runner:
vllm_xlite_outputs = runner.model.generate(prompts, sampling_params)
with VllmRunner(
model,
block_size=128,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_xlite_outputs_list = []
for output in vllm_xlite_outputs:
vllm_xlite_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_xlite_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_xlite_outputs",
)

View File

@@ -32,6 +32,7 @@ class TestNPUPlatform(TestBase):
def mock_vllm_ascend_config():
mock_ascend_config = MagicMock()
mock_ascend_config.torchair_graph_config.enabled = False
mock_ascend_config.xlite_graph_config.enabled = False
mock_ascend_config.enable_shared_expert_dp = False
return mock_ascend_config
@@ -512,6 +513,16 @@ class TestNPUPlatform(TestBase):
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.xlite_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.xlite.xlite_worker.XliteWorker",
)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch('vllm_ascend.utils.get_ascend_device_type',

View File

@@ -72,6 +72,10 @@ class AscendConfig:
self.torchair_graph_config = TorchairGraphConfig(
torchair_graph_config, vllm_config, additional_config)
xlite_graph_config = additional_config.get("xlite_graph_config", {})
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
vllm_config)
ascend_compilation_config = additional_config.get(
"ascend_compilation_config", {})
self.ascend_compilation_config = AscendCompilationConfig(
@@ -291,6 +295,29 @@ class TorchairGraphConfig:
)
class XliteGraphConfig:
"""
Configuration Object for xlite_graph_config from additional_config
"""
def __init__(self, xlite_graph_config, vllm_config):
self.enabled = xlite_graph_config.get("enabled", False)
self.full_mode = xlite_graph_config.get("full_mode", False)
if self.enabled:
if bool(vllm_config.speculative_config):
raise RuntimeError(
"Xlite graph mode is not compatible with speculative decoding. Please disable speculative decoding."
)
if vllm_config.parallel_config.pipeline_parallel_size > 1:
raise RuntimeError(
"Xlite graph mode is not compatible with pipeline parallelism. Please set pipeline_parallel_size to 1."
)
if vllm_config.cache_config.block_size != 128:
raise RuntimeError(
"Xlite graph mode is only compatible with block_size of 128. Please set block_size to 128."
)
class DumpConfig:
"""
Configuration object for dump/PrecisionDebugger settings.

View File

@@ -305,6 +305,11 @@ class NPUPlatform(Platform):
parallel_config.all2all_backend = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
elif ascend_config.xlite_graph_config.enabled:
logger.info(
"Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
)
parallel_config.worker_cls = "vllm_ascend.xlite.xlite_worker.XliteWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"

View File

275
vllm_ascend/xlite/xlite.py Normal file
View File

@@ -0,0 +1,275 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.utils import is_enable_nz
class XliteModel:
def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
raise NotImplementedError(
"Xlite Model initialize function not implemented.")
class LlamaXliteModel(XliteModel):
def initialize(
self, runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
dtype = vllm_config.model_config.dtype
params_dict = dict(runnable.named_parameters())
layers = runnable.model.layers
config = self._build_model_config(vllm_config)
xlite_model = Model()
xlite_model.embed = params_dict.get("model.embed_tokens.weight")
xlite_model.norm = params_dict.get("model.norm.weight")
if vllm_config.model_config.hf_config.tie_word_embeddings:
xlite_model.head = xlite_model.embed
else:
xlite_model.head = params_dict.get("lm_head.weight")
xlite_model.attn_norm = [
layer.input_layernorm.weight for layer in layers
]
xlite_model.attn_out = [
layer.self_attn.o_proj.weight for layer in layers
]
xlite_model.mha_qkv = [
layer.self_attn.qkv_proj.weight for layer in layers
]
xlite_model.mlp_norm = [
layer.post_attention_layernorm.weight for layer in layers
]
xlite_model.mlp_up_gate = [
layer.mlp.gate_up_proj.weight for layer in layers
]
xlite_model.mlp_down = [layer.mlp.down_proj.weight for layer in layers]
mha_qkv_bias = [
layer.self_attn.qkv_proj.bias for layer in layers
if hasattr(layer.self_attn.qkv_proj, "bias")
and layer.self_attn.qkv_proj.bias is not None
]
q_norm = [
layer.self_attn.q_norm.weight for layer in layers
if hasattr(layer.self_attn, "q_norm")
]
k_norm = [
layer.self_attn.k_norm.weight for layer in layers
if hasattr(layer.self_attn, "k_norm")
]
if len(mha_qkv_bias) != config.n_layers:
config.qkv_bias = False
else:
config.qkv_bias = True
xlite_model.mha_qkv_bias = mha_qkv_bias
if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
config.qk_norm = False
else:
config.qk_norm = True
xlite_model.mha_q_norm = q_norm
xlite_model.mha_k_norm = k_norm
rank = torch.distributed.get_rank()
xlite_model.init(config, rank)
freq_cis = self._precompute_freqs_cis(config.head_dim,
config.max_seq_len, dtype,
config.rope_theta)
return (xlite_model, freq_cis, config.hidden_size, dtype)
def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
hf_config = vllm_config.model_config.hf_config
config = ModelConfig()
config.vocab_size = hf_config.vocab_size
config.hidden_size = hf_config.hidden_size
config.n_layers = hf_config.num_hidden_layers
config.n_heads = hf_config.num_attention_heads
config.n_kv_heads = hf_config.num_key_value_heads
if hasattr(hf_config, "head_dim"):
config.head_dim = hf_config.head_dim
else:
config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
config.rope_head_dim = config.head_dim
config.norm_eps = hf_config.rms_norm_eps
config.rope_theta = hf_config.rope_theta
config.softmax_scale = config.head_dim**-0.5
config.n_dense_layers = hf_config.num_hidden_layers
config.intermediate_size = hf_config.intermediate_size
config.def_tp_size = get_tensor_model_parallel_world_size()
config.def_dp_size = 1
config.moe_ep_size = 1
config.moe_tp_size = 1
config.attn_type = AttnMHA
config.weight_nz = is_enable_nz()
scheduler_config = vllm_config.scheduler_config
max_batch_size = scheduler_config.max_num_seqs
max_seq_len = vllm_config.model_config.max_model_len
config.max_m = scheduler_config.max_num_batched_tokens
config.max_batch_size = max_batch_size
config.max_seq_len = max_seq_len
config.block_size = vllm_config.cache_config.block_size
return config
def _precompute_freqs_cis(self,
dim: int,
end: int,
dtype: torch.dtype,
theta: float = 10000.0):
freqs = 1.0 / (theta**(torch.arange(
0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
cos_cache = freqs.cos().to(dtype)
sin_cache = freqs.sin().to(dtype)
freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
return freq_cis.to(device='npu')
def xlite_model_init(
runnable: nn.Module,
vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
strategy_map = {
"LlamaForCausalLM": LlamaXliteModel,
"Qwen2ForCausalLM": LlamaXliteModel,
"Qwen3ForCausalLM": LlamaXliteModel,
}
architecture = vllm_config.model_config.architectures[0]
strategy_class = strategy_map.get(architecture)
if not strategy_class:
raise ValueError(f"{architecture} not supported!")
return strategy_class().initialize(runnable, vllm_config)
class XliteWrapper:
"""
xlite graph wrapper
"""
def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
self.runnable = runnable
self.full_mode = get_ascend_config().xlite_graph_config.full_mode
rank = torch.distributed.get_rank()
local_rank = get_world_group().local_rank
self.xlite_rt = Runtime(local_rank, 0, rank,
get_tensor_model_parallel_world_size())
(self.xlite_model, self.freq_cis, hidden_size,
dtype) = xlite_model_init(runnable, vllm_config)
rt_pool_size = self.xlite_model.get_tensor_pool_size()
if rank == 0:
logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
raise ValueError(
f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
)
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.hidden_states = torch.empty(max_num_tokens,
hidden_size,
device=f"npu:{local_rank}",
dtype=dtype)
def __getattr__(self, key: str):
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(f"Attribute {key} not exists in the runnable of "
f"xlite wrapper: {self.runnable}")
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
def register_kv_caches(self, kv_caches: Any):
self.kv_caches = kv_caches
def __call__(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
list[torch.Tensor]]:
forward_context = get_forward_context()
attn_metadata: Any = forward_context.attn_metadata
if attn_metadata is None:
return self.runnable(input_ids, positions, intermediate_tensors,
inputs_embeds)
attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is None or not isinstance(attn_metadata,
AscendMetadata):
return self.runnable(input_ids, positions, intermediate_tensors,
inputs_embeds)
with_prefill = attn_metadata.attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
if not with_prefill or self.full_mode:
batch = attn_metadata.num_prefills + attn_metadata.num_decodes
seq_lens = attn_metadata.seq_lens[:batch]
query_lens = attn_metadata.query_lens[:batch]
cached_lens = seq_lens - query_lens
xlite_attn_metadata = ModelAttnMeta()
xlite_attn_metadata.lens = query_lens.tolist()
xlite_attn_metadata.cached_lens = cached_lens.tolist()
xlite_attn_metadata.is_prefills = [
False
] * attn_metadata.num_decodes + [True] * attn_metadata.num_prefills
xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
).tolist()
h = self.hidden_states[:attn_metadata.num_actual_tokens]
stream = torch.npu.current_stream().npu_stream
if inputs_embeds is None:
self.xlite_model.forward(self.xlite_rt, input_ids,
xlite_attn_metadata, self.kv_caches,
self.freq_cis, h, stream)
else:
self.xlite_model.forward_with_inputs_embeds(
self.xlite_rt, inputs_embeds, xlite_attn_metadata,
self.kv_caches, self.freq_cis, h, stream)
return h
else:
return self.runnable(input_ids, positions, intermediate_tensors,
inputs_embeds)

View File

@@ -0,0 +1,36 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file
import torch.nn as nn
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class XliteModelRunner(NPUModelRunner):
def get_model(self) -> nn.Module:
return self.model.unwrap()
def load_model(self) -> None:
super().load_model()
from vllm_ascend.xlite.xlite import XliteWrapper
self.model = XliteWrapper(self.model, self.vllm_config)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
super().initialize_kv_cache(kv_cache_config)
self.model.register_kv_caches(self.kv_caches)

View File

@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.worker.worker_v1 import NPUWorker
from vllm_ascend.xlite.xlite_model_runner import XliteModelRunner
class XliteWorker(NPUWorker):
"""Xlite worker bases on NPUWorker. Only xlite specified code should be added in this class."""
def init_device(self):
"""Override init_device to init xlite model runner"""
self.device = self._init_device()
self.model_runner = XliteModelRunner(self.vllm_config, self.device)