RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The metadata dataclass contains an excessive number of variables. We
now inherit the community (upstream vLLM) metadata and remove the
variables that are no longer needed.
Todo:
1. Partially remove attn_state.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
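Roughly, the direction is to make the Ascend attention metadata a thin subclass of the community dataclass. Below is a minimal sketch under assumed names: UpstreamAttentionMetadata and its fields are illustrative stand-ins, not the actual vLLM definitions.

# Hypothetical sketch only: the upstream class and field names below are
# assumptions, not the actual vLLM / vllm-ascend definitions.
from dataclasses import dataclass

import torch


@dataclass
class UpstreamAttentionMetadata:
    # Stand-in for the community (upstream vLLM) attention metadata.
    num_actual_tokens: int
    seq_lens: torch.Tensor
    block_tables: torch.Tensor
    query_start_loc_cpu: torch.Tensor


@dataclass
class AscendMetadata(UpstreamAttentionMetadata):
    # Keep only the Ascend-specific extras; shared fields are inherited.
    attn_state: int = 0  # slated for partial removal per the Todo above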
278 lines · 11 KiB · Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                AscendMetadata)


class XliteModel:

    def initialize(
        self, runnable: nn.Module, vllm_config: VllmConfig
    ) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
        raise NotImplementedError(
            "Xlite Model initialize function not implemented.")


class LlamaXliteModel(XliteModel):

    def initialize(
        self, runnable: nn.Module, vllm_config: VllmConfig
    ) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
        dtype = vllm_config.model_config.dtype
        params_dict = dict(runnable.named_parameters())
        layers = runnable.model.layers

        config = self._build_model_config(vllm_config)
        xlite_model = Model()
        # Map the HF-named weights onto the xlite model layout.
        xlite_model.embed = params_dict.get("model.embed_tokens.weight")
        xlite_model.norm = params_dict.get("model.norm.weight")
        if vllm_config.model_config.hf_config.tie_word_embeddings:
            xlite_model.head = xlite_model.embed
        else:
            xlite_model.head = params_dict.get("lm_head.weight")
        xlite_model.attn_norm = [
            layer.input_layernorm.weight for layer in layers
        ]
        xlite_model.attn_out = [
            layer.self_attn.o_proj.weight for layer in layers
        ]
        xlite_model.mha_qkv = [
            layer.self_attn.qkv_proj.weight for layer in layers
        ]
        xlite_model.mlp_norm = [
            layer.post_attention_layernorm.weight for layer in layers
        ]
        xlite_model.mlp_up_gate = [
            layer.mlp.gate_up_proj.weight for layer in layers
        ]
        xlite_model.mlp_down = [layer.mlp.down_proj.weight for layer in layers]
        mha_qkv_bias = [
            layer.self_attn.qkv_proj.bias for layer in layers
            if hasattr(layer.self_attn.qkv_proj, "bias")
            and layer.self_attn.qkv_proj.bias is not None
        ]
        q_norm = [
            layer.self_attn.q_norm.weight for layer in layers
            if hasattr(layer.self_attn, "q_norm")
        ]
        k_norm = [
            layer.self_attn.k_norm.weight for layer in layers
            if hasattr(layer.self_attn, "k_norm")
        ]

        # Enable qkv bias / qk norm only when every layer provides the weights.
        if len(mha_qkv_bias) != config.n_layers:
            config.qkv_bias = False
        else:
            config.qkv_bias = True
            xlite_model.mha_qkv_bias = mha_qkv_bias

        if len(q_norm) != config.n_layers or len(k_norm) != config.n_layers:
            config.qk_norm = False
        else:
            config.qk_norm = True
            xlite_model.mha_q_norm = q_norm
            xlite_model.mha_k_norm = k_norm

        rank = torch.distributed.get_rank()
        xlite_model.init(config, rank)

        freq_cis = self._precompute_freqs_cis(config.head_dim,
                                              config.max_seq_len, dtype,
                                              config.rope_theta)

        return (xlite_model, freq_cis, config.hidden_size, dtype)

    def _build_model_config(self, vllm_config: VllmConfig) -> ModelConfig:
        hf_config = vllm_config.model_config.hf_config
        config = ModelConfig()
        config.vocab_size = hf_config.vocab_size
        config.hidden_size = hf_config.hidden_size
        config.n_layers = hf_config.num_hidden_layers
        config.n_heads = hf_config.num_attention_heads
        config.n_kv_heads = hf_config.num_key_value_heads
        if hasattr(hf_config, "head_dim"):
            config.head_dim = hf_config.head_dim
        else:
            config.head_dim = (hf_config.hidden_size //
                               hf_config.num_attention_heads)
        config.rope_head_dim = config.head_dim
        config.norm_eps = hf_config.rms_norm_eps
        config.rope_theta = hf_config.rope_theta
        config.softmax_scale = config.head_dim**-0.5
        config.n_dense_layers = hf_config.num_hidden_layers
        config.intermediate_size = hf_config.intermediate_size
        config.def_tp_size = get_tensor_model_parallel_world_size()
        config.def_dp_size = 1
        config.moe_ep_size = 1
        config.moe_tp_size = 1

        config.attn_type = AttnMHA
        config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ
        scheduler_config = vllm_config.scheduler_config
        max_batch_size = scheduler_config.max_num_seqs
        max_seq_len = vllm_config.model_config.max_model_len
        config.max_m = scheduler_config.max_num_batched_tokens
        config.max_batch_size = max_batch_size
        config.max_seq_len = max_seq_len
        config.block_size = vllm_config.cache_config.block_size
        return config

    def _precompute_freqs_cis(self,
                              dim: int,
                              end: int,
                              dtype: torch.dtype,
                              theta: float = 10000.0):
        # Standard RoPE tables: per-dimension inverse frequencies, outer
        # product with the positions, then cos/sin caches concatenated as
        # [cos | sin] and moved to the NPU.
        freqs = 1.0 / (theta**(torch.arange(
            0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        cos_cache = freqs.cos().to(dtype)
        sin_cache = freqs.sin().to(dtype)
        freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
        return freq_cis.to(device='npu')


def xlite_model_init(
    runnable: nn.Module, vllm_config: VllmConfig
) -> Tuple[Model, torch.Tensor, int, torch.dtype]:
    # Dispatch on the HF architecture name; all currently supported
    # architectures share the Llama-style weight layout.
    strategy_map = {
        "LlamaForCausalLM": LlamaXliteModel,
        "Qwen2ForCausalLM": LlamaXliteModel,
        "Qwen3ForCausalLM": LlamaXliteModel,
    }

    architecture = vllm_config.model_config.architectures[0]
    strategy_class = strategy_map.get(architecture)
    if not strategy_class:
        raise ValueError(f"{architecture} not supported!")
    return strategy_class().initialize(runnable, vllm_config)


class XliteWrapper:
    """
    xlite graph wrapper
    """

    def __init__(self, runnable: nn.Module, vllm_config: VllmConfig):
        self.runnable = runnable
        self.full_mode = get_ascend_config().xlite_graph_config.full_mode

        rank = torch.distributed.get_rank()
        local_rank = get_world_group().local_rank
        self.xlite_rt = Runtime(local_rank, 0, rank,
                                get_tensor_model_parallel_world_size())

        (self.xlite_model, self.freq_cis, hidden_size,
         dtype) = xlite_model_init(runnable, vllm_config)

        rt_pool_size = self.xlite_model.get_tensor_pool_size()
        if rank == 0:
            logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
        if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
            raise ValueError(
                f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
            )

        # Preallocate the output hidden-states buffer for the maximum
        # number of batched tokens.
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_states = torch.empty(max_num_tokens,
                                         hidden_size,
                                         device=f"npu:{local_rank}",
                                         dtype=dtype)

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable "
                             f"of xlite wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def register_kv_caches(self, kv_caches: Any):
        self.kv_caches = kv_caches

    def __call__(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
                                                     list[torch.Tensor]]:
        forward_context = get_forward_context()
        attn_metadata: Any = forward_context.attn_metadata
        if attn_metadata is None:
            # No attention metadata: fall back to the eager model.
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        attn_metadata = next(iter(attn_metadata.values()), None)
        if attn_metadata is None or not isinstance(attn_metadata,
                                                   AscendMetadata):
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)

        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]

        if not with_prefill or self.full_mode:
            # Build the xlite attention metadata and run the xlite graph.
            batch = attn_metadata.num_prefills + attn_metadata.num_decodes
            seq_lens = attn_metadata.seq_lens[:batch]
            query_lens = attn_metadata.query_start_loc_cpu[
                1:] - attn_metadata.query_start_loc_cpu[:-1]
            query_lens = query_lens[:batch]
            cached_lens = seq_lens - query_lens

            xlite_attn_metadata = ModelAttnMeta()
            xlite_attn_metadata.lens = query_lens.tolist()
            xlite_attn_metadata.cached_lens = cached_lens.tolist()
            xlite_attn_metadata.is_prefills = [
                False
            ] * attn_metadata.num_decodes + [True] * attn_metadata.num_prefills
            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
            ).tolist()

            h = self.hidden_states[:attn_metadata.num_actual_tokens]
            stream = torch.npu.current_stream().npu_stream
            if inputs_embeds is None:
                self.xlite_model.forward(self.xlite_rt, input_ids,
                                         xlite_attn_metadata, self.kv_caches,
                                         self.freq_cis, h, stream)
            else:
                self.xlite_model.forward_with_inputs_embeds(
                    self.xlite_rt, inputs_embeds, xlite_attn_metadata,
                    self.kv_caches, self.freq_cis, h, stream)
            return h
        else:
            # Prefill batches (unless full_mode is enabled) still go through
            # the eager model.
            return self.runnable(input_ids, positions, intermediate_tensors,
                                 inputs_embeds)
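For context, here is a minimal usage sketch of the wrapper above. The helper maybe_wrap_with_xlite is hypothetical and only illustrates the intended call pattern; the actual model-runner integration in vllm-ascend differs.

# Hypothetical usage sketch; the helper below is illustrative only and is
# not part of the file above.
def maybe_wrap_with_xlite(model: nn.Module, vllm_config: VllmConfig,
                          kv_caches: Any) -> nn.Module | XliteWrapper:
    # Wrap the eager model with the xlite graph; unsupported architectures
    # raise ValueError inside xlite_model_init, so fall back to eager mode.
    try:
        wrapped = XliteWrapper(model, vllm_config)
    except ValueError:
        return model
    # KV caches must be registered before the first forward call, since
    # XliteWrapper.__call__ reads self.kv_caches.
    wrapped.register_kv_caches(kv_caches)
    return wrapped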