xc-llm-ascend/vllm_ascend/xlite/xlite.py
Li Wang 484e7c59dc [CI] optimize lint term (#5986)
### What this PR does / why we need it?
This patch optimizes the lint check step. The main idea is to
reduce unnecessary installation time.
1. Installing vllm is not required; appending the path of the vllm
src to `PYTHONPATH` is sufficient (see the sketch after this list).
2. Installing `requirements-dev.txt` is not required either; we have a
pre-built image `quay.io/ascend-ci/vllm-ascend:lint` with all the
requirements installed in advance.
**NOTE**: the image is rebuilt on: 1) a daily schedule; 2) changes to
the requirements files; 3) manual trigger. This keeps the dependencies
in the image as up-to-date as possible.
3. The `mypy` check was separated from the `pre-commit` hook for performance
reasons; we found that running `mypy` inside the `pre-commit` hook
was too slow.
4. Reduce CPU core consumption from 16 to 8.
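
As a rough illustration of items 1 and 2 (a hypothetical sketch, not the actual workflow file; the job layout, checkout path, and step names are placeholders):

```yaml
# Hypothetical sketch only: run lint inside the pre-built image and put the
# vllm sources on PYTHONPATH instead of pip-installing them.
lint:
  runs-on: ubuntu-latest
  container: quay.io/ascend-ci/vllm-ascend:lint
  env:
    PYTHONPATH: ${{ github.workspace }}/vllm   # placeholder path to the vllm checkout
  steps:
    - uses: actions/checkout@v4
    - name: Run pre-commit (mypy runs as a separate step)
      run: pre-commit run --all-files
```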

### Does this PR introduce _any_ user-facing change?
The end-to-end lint time was reduced from 20 minutes per PR to 8 minutes per PR.
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2026-01-22 15:46:59 +08:00

370 lines
16 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_ep_group,
                              get_tensor_model_parallel_world_size,
                              get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime,  # type: ignore[attr-defined]
                      ScoringFuncSoftmax)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                AscendMetadata)


class XliteModel:
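    """Base strategy for building an xlite `Model` from a vLLM `nn.Module`."""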

    def initialize(
            self, runnable: nn.Module,
            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
        raise NotImplementedError(
            "Xlite Model initialize function not implemented.")


class LlamaXliteModel(XliteModel):
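    """Builds the xlite model for dense Llama/Qwen-style decoder architectures."""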

    def initialize(
            self, runnable: nn.Module,
            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
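        """Build the xlite config and weights, initialize the model on the
        current rank, and precompute the rotary cos/sin cache."""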
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)
        freq_cis = self._precompute_freqs_cis(config.head_dim,
                                              config.max_seq_len, dtype,
                                              config.rope_theta)
        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
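        """Map the HuggingFace text config onto xlite's `ModelConfig`."""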
        hf_config = vllm_config.model_config.hf_text_config
        if hasattr(hf_config, "text_config"):
            hf_config = hf_config.text_config
        config = ModelConfig()
        config.vocab_size = hf_config.vocab_size
        config.hidden_size = hf_config.hidden_size
        config.n_layers = hf_config.num_hidden_layers
        config.n_heads = hf_config.num_attention_heads
        config.n_kv_heads = hf_config.num_key_value_heads
        if hasattr(hf_config, "head_dim"):
            config.head_dim = hf_config.head_dim
        else:
            config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        config.rope_head_dim = config.head_dim
        config.norm_eps = hf_config.rms_norm_eps
        config.rope_theta = hf_config.rope_theta
        config.softmax_scale = config.head_dim**-0.5
        config.n_dense_layers = hf_config.num_hidden_layers
        config.intermediate_size = hf_config.intermediate_size
        config.def_tp_size = get_tensor_model_parallel_world_size()
        config.def_dp_size = 1
        config.moe_ep_size = 1
        config.moe_tp_size = 1
        config.attn_type = AttnMHA
        config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
        scheduler_config = vllm_config.scheduler_config
        max_batch_size = scheduler_config.max_num_seqs
        max_seq_len = vllm_config.model_config.max_model_len
        config.max_m = scheduler_config.max_num_batched_tokens
        config.max_batch_size = max_batch_size
        config.max_seq_len = max_seq_len
        config.block_size = vllm_config.cache_config.block_size
        return config

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
                     config: ModelConfig) -> Model:
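        """Collect the per-layer weight tensors from the vLLM module into an
        xlite `Model`; QKV bias and Q/K norm are enabled only when every layer
        provides them."""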
        params_dict = dict(runnable.named_parameters())
        if hasattr(runnable, "language_model"):
            layers = runnable.language_model.model.layers
            model_prefix = "language_model."
        else:
            layers = runnable.model.layers
            model_prefix = ""
        xlite_model = Model()
        xlite_model.embed = params_dict.get(model_prefix +
                                            "model.embed_tokens.weight")
        xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
        if vllm_config.model_config.hf_text_config.tie_word_embeddings:
            xlite_model.head = xlite_model.embed
        else:
            xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
        xlite_model.attn_norm = [
            layer.input_layernorm.weight for layer in layers
        ]
        xlite_model.attn_out = [
            layer.self_attn.o_proj.weight for layer in layers
        ]
        xlite_model.mha_qkv = [
            layer.self_attn.qkv_proj.weight for layer in layers
        ]
        xlite_model.mlp_norm = [
            layer.post_attention_layernorm.weight for layer in layers
        ]
        xlite_model.mlp_up_gate = [
            layer.mlp.gate_up_proj.weight for layer in layers
            if hasattr(layer.mlp, "gate_up_proj")
            and layer.mlp.gate_up_proj.weight is not None
        ]
        xlite_model.mlp_down = [
            layer.mlp.down_proj.weight for layer in layers
            if hasattr(layer.mlp, "down_proj")
            and layer.mlp.down_proj.weight is not None
        ]
        mha_qkv_bias = [
            layer.self_attn.qkv_proj.bias for layer in layers
            if hasattr(layer.self_attn.qkv_proj, "bias")
            and layer.self_attn.qkv_proj.bias is not None
        ]
        q_norm = [
            layer.self_attn.q_norm.weight for layer in layers
            if hasattr(layer.self_attn, "q_norm")
        ]
        k_norm = [
            layer.self_attn.k_norm.weight for layer in layers
            if hasattr(layer.self_attn, "k_norm")
        ]
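        # Enable QKV bias and Q/K norm only when every decoder layer provides
        # the corresponding tensors.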
        if len(mha_qkv_bias) != config.n_layers:
            config.qkv_bias = False
        else:
            config.qkv_bias = True
            xlite_model.mha_qkv_bias = mha_qkv_bias
        if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
            config.qk_norm = False
        else:
            config.qk_norm = True
            xlite_model.mha_q_norm = q_norm
            xlite_model.mha_k_norm = k_norm
        return xlite_model

    def _precompute_freqs_cis(self,
                              dim: int,
                              end: int,
                              dtype: torch.dtype,
                              theta: float = 10000.0):
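        """Precompute the rotary embedding cache: cos and sin tables of shape
        (end, dim // 2), concatenated along the last dim and moved to the NPU."""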
        freqs = 1.0 / (theta**(torch.arange(
            0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        cos_cache = freqs.cos().to(dtype)
        sin_cache = freqs.sin().to(dtype)
        freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
        return freq_cis.to(device='npu')


class QwenMoeXliteModel(LlamaXliteModel):
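    """MoE variant of the strategy (e.g. Qwen3-MoE): adds expert routing config
    and expert weights on top of the dense build."""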

    def initialize(
            self, runnable: nn.Module,
            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
        if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
            architecture = vllm_config.model_config.architectures[0]
            raise ValueError(
                f"{architecture} does not support VLLM_ASCEND_ENABLE_NZ = 2!")
        dtype = vllm_config.model_config.dtype
        config = self._build_model_config(vllm_config)
        xlite_model = self._build_model(runnable, vllm_config, config)
        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)
        freq_cis = super()._precompute_freqs_cis(config.head_dim,
                                                 config.max_seq_len, dtype,
                                                 config.rope_theta)
        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
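        """Extend the dense config with MoE expert counts, sizes, and the
        expert-parallel / tensor-parallel layout."""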
        config = super()._build_model_config(vllm_config)
        hf_config = vllm_config.model_config.hf_text_config
        ep_group = get_ep_group()
        config.n_layers = hf_config.max_window_layers
        config.n_dense_layers = 0
        config.n_routed_experts = hf_config.num_experts
        config.n_shared_experts = 0
        config.n_act_experts = hf_config.num_experts_per_tok
        config.def_dp_size = vllm_config.parallel_config.data_parallel_size
        config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
        config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
        config.experts_weight_transpose = True  # type: ignore
        config.moe_intermediate_size = hf_config.moe_intermediate_size
        config.norm_topk_prob = hf_config.norm_topk_prob  # type: ignore
        config.scoring_func = ScoringFuncSoftmax  # type: ignore
        return config

    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
                     config: ModelConfig) -> Model:
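        """Reuse the dense weight collection, then add the router gate and the
        per-layer, per-local-expert up/gate and down projection weights."""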
        xlite_model = super()._build_model(runnable, vllm_config, config)
        layers = runnable.model.layers
        xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
        xlite_model.re_up_gate = [
            layer.mlp.experts.w13_weight[i] for layer in layers
            for i in range(layer.mlp.experts.local_num_experts)
        ]
        xlite_model.re_down = [
            layer.mlp.experts.w2_weight[i] for layer in layers
            for i in range(layer.mlp.experts.local_num_experts)
        ]
        return xlite_model


def xlite_model_init(
        runnable: nn.Module,
        vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
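    """Pick the xlite build strategy for the model architecture and run it."""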
    strategy_map = {
        "LlamaForCausalLM": LlamaXliteModel,
        "Qwen2ForCausalLM": LlamaXliteModel,
        "Qwen3ForCausalLM": LlamaXliteModel,
        "Qwen3VLForConditionalGeneration": LlamaXliteModel,
        "Qwen3MoeForCausalLM": QwenMoeXliteModel,
    }
    architecture = vllm_config.model_config.architectures[0]
    strategy_class = strategy_map.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize(runnable, vllm_config)


class XliteWrapper:
    """
    xlite graph wrapper: routes eligible batches through the xlite graph
    runtime and falls back to the wrapped vLLM module otherwise.
    """

    def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
        self.runnable = runnable
        self.full_mode = get_ascend_config().xlite_graph_config.full_mode
        rank = torch.distributed.get_rank()
        local_rank = get_world_group().local_rank
        self.xlite_rt = Runtime(local_rank, 0, rank,
                                get_tensor_model_parallel_world_size(),
                                vllm_config.parallel_config.data_parallel_size)
        (self.xlite_model, self.freq_cis, hidden_size,
         dtype) = xlite_model_init(runnable, vllm_config)
        rt_pool_size = self.xlite_model.get_tensor_pool_size()
        if rank == 0:
            logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
        if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
            raise ValueError(
                f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
            )
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_states = torch.empty(max_num_tokens,
                                         hidden_size,
                                         device=f"npu:{local_rank}",
                                         dtype=dtype)

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable "
                             f"of xlite wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def register_kv_caches(self, kv_caches: Any):
        self.kv_caches = kv_caches

    def __call__(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
                                                    list[torch.Tensor]]:
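        """Run the xlite graph for decode-only batches (or for all batches in
        full mode); otherwise fall back to the wrapped vLLM module."""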
        forward_context = get_forward_context()
        attn_metadata: Any = forward_context.attn_metadata
        if attn_metadata is None:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)
        attn_metadata = next(iter(attn_metadata.values()), None)
        if attn_metadata is None or not isinstance(attn_metadata,
                                                   AscendMetadata):
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)
        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]
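        # Use the xlite graph for decode-only / spec-decoding batches, or for
        # every batch when full mode is enabled.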
        if not with_prefill or self.full_mode:
            # TODO: When vllm_ascend enables graph mode, attn_metadata.num_decodes
            # is padded for decode requests, so num_decode_tokens is used here
            # for now. Once MTP is enabled, a single decode request may cover
            # multiple tokens, which will need to be handled.
            num_decodes = attn_metadata.num_decode_tokens
            num_prefills = attn_metadata.num_prefills
            batch = num_prefills + num_decodes
            seq_lens = attn_metadata.seq_lens[:batch]
            seq_tensor = torch.cat([
                torch.tensor([0]),
                torch.tensor(attn_metadata.actual_seq_lengths_q)
            ],
                                   dim=0)
            query_lens = seq_tensor[1:] - seq_tensor[:-1]
            query_lens = query_lens[:batch]
            cached_lens = seq_lens - query_lens
            xlite_attn_metadata = ModelAttnMeta()
            xlite_attn_metadata.lens = query_lens.tolist()
            xlite_attn_metadata.cached_lens = cached_lens.tolist()
            xlite_attn_metadata.is_prefills = [False] * num_decodes + [
                True
            ] * num_prefills
            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
            ).tolist()
            h = self.hidden_states[:attn_metadata.num_actual_tokens]
            stream = torch.npu.current_stream().npu_stream
            if inputs_embeds is None:
                self.xlite_model.forward(self.xlite_rt, input_ids,
                                         xlite_attn_metadata, self.kv_caches,
                                         self.freq_cis, h, stream)
            else:
                self.xlite_model.forward_with_inputs_embeds(
                    self.xlite_rt, inputs_embeds, xlite_attn_metadata,
                    self.kv_caches, self.freq_cis, h, stream)
            return h
        else:
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)