### What this PR does / why we need it?
**Scope of Changes**:

| File Path |
| :--- |
| `vllm_ascend/ops/layer_shard_linear.py` |
| `vllm_ascend/ops/linear.py` |
| `vllm_ascend/ops/linear_op.py` |
| `vllm_ascend/worker/worker.py` |
| `vllm_ascend/patch/worker/patch_bert.py` |
| `vllm_ascend/patch/worker/patch_deepseek.py` |
| `vllm_ascend/patch/worker/patch_distributed.py` |
| `vllm_ascend/patch/worker/patch_module.py` |
| `vllm_ascend/patch/worker/patch_multimodal_merge.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next.py` |
| `vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` |
| `vllm_ascend/patch/worker/patch_rejection_sampler.py` |
| `vllm_ascend/patch/worker/patch_rope.py` |
| `vllm_ascend/patch/worker/patch_triton.py` |
| `vllm_ascend/patch/worker/patch_unquantized_gemm.py` |
| `vllm_ascend/patch/worker/patch_v2_egale.py` |
| `vllm_ascend/worker/npu_input_batch.py` |
| `vllm_ascend/worker/v2/aclgraph_utils.py` |
| `vllm_ascend/worker/v2/attn_utils.py` |
| `vllm_ascend/worker/v2/model_runner.py` |
| `vllm_ascend/worker/v2/sample/gumbel.py` |
| `vllm_ascend/worker/v2/sample/penalties.py` |
| `vllm_ascend/worker/v2/sample/sampler.py` |
| `vllm_ascend/worker/v2/spec_decode/__init__.py` |
| `vllm_ascend/worker/v2/spec_decode/eagle.py` |
| `vllm_ascend/worker/v2/states.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -14,49 +14,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Callable, Tuple
+from collections.abc import Callable
+from typing import Any

 import torch
 import torch.nn as nn
 from vllm.config import VllmConfig
-from vllm.distributed import (get_ep_group,
-                              get_tensor_model_parallel_world_size,
-                              get_world_group)
+from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size, get_world_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.sequence import IntermediateTensors
-from xlite._C import (AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime,  # type: ignore[attr-defined]
-                      ScoringFuncSoftmax)
+from xlite._C import (  # type: ignore[attr-defined]
+    AttnMHA,
+    Model,
+    ModelAttnMeta,
+    ModelConfig,
+    Runtime,
+    ScoringFuncSoftmax,
+)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata


 class XliteModel:

-    def initialize(
-            self, runnable: nn.Module,
-            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
-        raise NotImplementedError(
-            "Xlite Model initialize function not implemented.")
+    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
+        raise NotImplementedError("Xlite Model initialize function not implemented.")


 class LlamaXliteModel(XliteModel):

-    def initialize(
-            self, runnable: nn.Module,
-            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
+    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
         dtype = vllm_config.model_config.dtype
         config = self._build_model_config(vllm_config)
         xlite_model = self._build_model(runnable, vllm_config, config)
         rank = torch.distributed.get_rank()
         xlite_model.init(config, rank)

-        freq_cis = self._precompute_freqs_cis(config.head_dim,
-                                              config.max_seq_len, dtype,
-                                              config.rope_theta)
+        freq_cis = self._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)

         return (xlite_model, freq_cis, config.hidden_size, dtype)

@@ -96,8 +91,7 @@ class LlamaXliteModel(XliteModel):
         config.block_size = vllm_config.cache_config.block_size
         return config

-    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
-                     config: ModelConfig) -> Model:
+    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
         params_dict = dict(runnable.named_parameters())

         if hasattr(runnable, "language_model"):
@@ -108,48 +102,33 @@ class LlamaXliteModel(XliteModel):
             model_prefix = ""

         xlite_model = Model()
-        xlite_model.embed = params_dict.get(model_prefix +
-                                            "model.embed_tokens.weight")
+        xlite_model.embed = params_dict.get(model_prefix + "model.embed_tokens.weight")
         xlite_model.norm = params_dict.get(model_prefix + "model.norm.weight")
         if vllm_config.model_config.hf_text_config.tie_word_embeddings:
             xlite_model.head = xlite_model.embed
         else:
             xlite_model.head = params_dict.get(model_prefix + "lm_head.weight")
-        xlite_model.attn_norm = [
-            layer.input_layernorm.weight for layer in layers
-        ]
-        xlite_model.attn_out = [
-            layer.self_attn.o_proj.weight for layer in layers
-        ]
-        xlite_model.mha_qkv = [
-            layer.self_attn.qkv_proj.weight for layer in layers
-        ]
-        xlite_model.mlp_norm = [
-            layer.post_attention_layernorm.weight for layer in layers
-        ]
+        xlite_model.attn_norm = [layer.input_layernorm.weight for layer in layers]
+        xlite_model.attn_out = [layer.self_attn.o_proj.weight for layer in layers]
+        xlite_model.mha_qkv = [layer.self_attn.qkv_proj.weight for layer in layers]
+        xlite_model.mlp_norm = [layer.post_attention_layernorm.weight for layer in layers]
         xlite_model.mlp_up_gate = [
-            layer.mlp.gate_up_proj.weight for layer in layers
-            if hasattr(layer.mlp, "gate_up_proj")
-            and layer.mlp.gate_up_proj.weight is not None
+            layer.mlp.gate_up_proj.weight
+            for layer in layers
+            if hasattr(layer.mlp, "gate_up_proj") and layer.mlp.gate_up_proj.weight is not None
         ]
         xlite_model.mlp_down = [
-            layer.mlp.down_proj.weight for layer in layers
-            if hasattr(layer.mlp, "down_proj")
-            and layer.mlp.down_proj.weight is not None
+            layer.mlp.down_proj.weight
+            for layer in layers
+            if hasattr(layer.mlp, "down_proj") and layer.mlp.down_proj.weight is not None
         ]
         mha_qkv_bias = [
-            layer.self_attn.qkv_proj.bias for layer in layers
-            if hasattr(layer.self_attn.qkv_proj, "bias")
-            and layer.self_attn.qkv_proj.bias is not None
-        ]
-        q_norm = [
-            layer.self_attn.q_norm.weight for layer in layers
-            if hasattr(layer.self_attn, "q_norm")
-        ]
-        k_norm = [
-            layer.self_attn.k_norm.weight for layer in layers
-            if hasattr(layer.self_attn, "k_norm")
+            layer.self_attn.qkv_proj.bias
+            for layer in layers
+            if hasattr(layer.self_attn.qkv_proj, "bias") and layer.self_attn.qkv_proj.bias is not None
         ]
+        q_norm = [layer.self_attn.q_norm.weight for layer in layers if hasattr(layer.self_attn, "q_norm")]
+        k_norm = [layer.self_attn.k_norm.weight for layer in layers if hasattr(layer.self_attn, "k_norm")]

         if len(mha_qkv_bias) != config.n_layers:
             config.qkv_bias = False
@@ -157,7 +136,7 @@ class LlamaXliteModel(XliteModel):
             config.qkv_bias = True
             xlite_model.mha_qkv_bias = mha_qkv_bias

-        if (len(q_norm) != config.n_layers or len(k_norm) != config.n_layers):
+        if len(q_norm) != config.n_layers or len(k_norm) != config.n_layers:
             config.qk_norm = False
         else:
             config.qk_norm = True
@@ -166,39 +145,28 @@ class LlamaXliteModel(XliteModel):

         return xlite_model

-    def _precompute_freqs_cis(self,
-                              dim: int,
-                              end: int,
-                              dtype: torch.dtype,
-                              theta: float = 10000.0):
-        freqs = 1.0 / (theta**(torch.arange(
-            0, dim, 2, dtype=torch.float32, device='cpu')[:(dim // 2)] / dim))
+    def _precompute_freqs_cis(self, dim: int, end: int, dtype: torch.dtype, theta: float = 10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cpu")[: (dim // 2)] / dim))
         t = torch.arange(end, device=freqs.device)  # type: ignore
         freqs = torch.outer(t, freqs).float()  # type: ignore
         cos_cache = freqs.cos().to(dtype)
         sin_cache = freqs.sin().to(dtype)
         freq_cis = torch.cat((cos_cache, sin_cache), dim=-1)
-        return freq_cis.to(device='npu')
+        return freq_cis.to(device="npu")


 class QwenMoeXliteModel(LlamaXliteModel):

-    def initialize(
-            self, runnable: nn.Module,
-            vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
+    def initialize(self, runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
         if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
             architecture = vllm_config.model_config.architectures[0]
-            raise ValueError(
-                f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
+            raise ValueError(f"{architecture} not support VLLM_ASCEND_ENABLE_NZ = 2!")
         dtype = vllm_config.model_config.dtype
         config = self._build_model_config(vllm_config)
         xlite_model = self._build_model(runnable, vllm_config, config)
         rank = torch.distributed.get_rank()
         xlite_model.init(config, rank)

-        freq_cis = super()._precompute_freqs_cis(config.head_dim,
-                                                 config.max_seq_len, dtype,
-                                                 config.rope_theta)
+        freq_cis = super()._precompute_freqs_cis(config.head_dim, config.max_seq_len, dtype, config.rope_theta)

         return (xlite_model, freq_cis, config.hidden_size, dtype)

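For readers tracing the rotary cache built by `_precompute_freqs_cis` above, here is a minimal, CPU-runnable sketch of the same computation; the helper name is illustrative, and the final `.to(device="npu")` move from the diff is omitted so the snippet runs anywhere:

```python
import torch


def precompute_freqs_cis_sketch(dim: int, end: int, dtype: torch.dtype, theta: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies for the even dimension indices: shape [dim // 2].
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim))
    # Angle of every (position, frequency) pair: shape [end, dim // 2].
    angles = torch.outer(torch.arange(end, dtype=torch.float32), freqs)
    # Cos cache and sin cache concatenated along the last dim: shape [end, dim].
    return torch.cat((angles.cos().to(dtype), angles.sin().to(dtype)), dim=-1)


# Example: a 128-dim head with a 4096-token window yields a [4096, 128] cache.
print(precompute_freqs_cis_sketch(dim=128, end=4096, dtype=torch.float16).shape)
```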
@@ -214,32 +182,27 @@ class QwenMoeXliteModel(LlamaXliteModel):
         config.def_dp_size = vllm_config.parallel_config.data_parallel_size
         config.moe_ep_size = ep_group.world_size if vllm_config.parallel_config.enable_expert_parallel else 1
         config.moe_tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else ep_group.world_size
-        config.experts_weight_transpose = True # type: ignore
+        config.experts_weight_transpose = True  # type: ignore
         config.moe_intermediate_size = hf_config.moe_intermediate_size
-        config.norm_topk_prob = hf_config.norm_topk_prob # type: ignore
-        config.scoring_func = ScoringFuncSoftmax # type: ignore
+        config.norm_topk_prob = hf_config.norm_topk_prob  # type: ignore
+        config.scoring_func = ScoringFuncSoftmax  # type: ignore
         return config

-    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig,
-                     config: ModelConfig) -> Model:
+    def _build_model(self, runnable: nn.Module, vllm_config: VllmConfig, config: ModelConfig) -> Model:
         xlite_model = super()._build_model(runnable, vllm_config, config)
         layers = runnable.model.layers
         xlite_model.gate = [layer.mlp.gate.weight for layer in layers]
         xlite_model.re_up_gate = [
-            layer.mlp.experts.w13_weight[i] for layer in layers
-            for i in range(layer.mlp.experts.local_num_experts)
+            layer.mlp.experts.w13_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
         ]
         xlite_model.re_down = [
-            layer.mlp.experts.w2_weight[i] for layer in layers
-            for i in range(layer.mlp.experts.local_num_experts)
+            layer.mlp.experts.w2_weight[i] for layer in layers for i in range(layer.mlp.experts.local_num_experts)
         ]

         return xlite_model


-def xlite_model_init(
-        runnable: nn.Module,
-        vllm_config: VllmConfig) -> Tuple[Model, int, int, torch.dtype]:
+def xlite_model_init(runnable: nn.Module, vllm_config: VllmConfig) -> tuple[Model, int, int, torch.dtype]:
     strategy_map = {
         "LlamaForCausalLM": LlamaXliteModel,
         "Qwen2ForCausalLM": LlamaXliteModel,
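The hunk above is cut off inside `strategy_map`, but the idea is a plain architecture-string lookup: `xlite_model_init` picks an `XliteModel` subclass for the model's reported architecture and delegates to its `initialize`. A hedged sketch of that dispatch (the lookup helper and its error handling are assumptions, not code from the diff):

```python
# Hypothetical dispatch sketch; only the two map entries visible in the diff are real.
def pick_xlite_strategy(architectures: list[str], strategy_map: dict[str, type]) -> type:
    # Return the first registered strategy that matches one of the model's architectures.
    for arch in architectures:
        if arch in strategy_map:
            return strategy_map[arch]
    raise ValueError(f"no xlite model strategy registered for {architectures}")
```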
@@ -266,33 +229,26 @@ class XliteWrapper:

         rank = torch.distributed.get_rank()
         local_rank = get_world_group().local_rank
-        self.xlite_rt = Runtime(local_rank, 0, rank,
-                                get_tensor_model_parallel_world_size(),
-                                vllm_config.parallel_config.data_parallel_size)
+        self.xlite_rt = Runtime(
+            local_rank, 0, rank, get_tensor_model_parallel_world_size(), vllm_config.parallel_config.data_parallel_size
+        )

-        (self.xlite_model, self.freq_cis, hidden_size,
-         dtype) = xlite_model_init(runnable, vllm_config)
+        (self.xlite_model, self.freq_cis, hidden_size, dtype) = xlite_model_init(runnable, vllm_config)

         rt_pool_size = self.xlite_model.get_tensor_pool_size()
         if rank == 0:
             logger.info(f"xlite runtime pool size: {rt_pool_size} MB")
         if self.xlite_rt.init_tensor_pool(rt_pool_size) != 0:
-            raise ValueError(
-                f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB"
-            )
+            raise ValueError(f"xlite wrapper init failed! runtime pool size: {rt_pool_size} MB")

         max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.hidden_states = torch.empty(max_num_tokens,
-                                         hidden_size,
-                                         device=f"npu:{local_rank}",
-                                         dtype=dtype)
+        self.hidden_states = torch.empty(max_num_tokens, hidden_size, device=f"npu:{local_rank}", dtype=dtype)

     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
         if hasattr(self.runnable, key):
             return getattr(self.runnable, key)
-        raise AttributeError(f"Attribute {key} not exists in the runnable of "
-                             f"xlite wrapper: {self.runnable}")
+        raise AttributeError(f"Attribute {key} not exists in the runnable of xlite wrapper: {self.runnable}")

     def unwrap(self) -> Callable:
         # in case we need to access the original runnable.
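`__getattr__` plus `unwrap` above is a standard transparent-wrapper pattern: attributes the wrapper does not define fall through to the wrapped runnable, and `unwrap()` hands the original module back to callers that need it. A minimal, self-contained sketch of the pattern (class and method names here are illustrative, not the PR's API):

```python
class DelegatingWrapper:
    """Forward unknown attribute lookups to the wrapped object."""

    def __init__(self, runnable):
        self.runnable = runnable

    def __getattr__(self, key: str):
        # Only called when normal lookup on the wrapper itself fails.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"attribute {key} does not exist on the wrapped object: {self.runnable}")

    def unwrap(self):
        # Give callers direct access to the original object.
        return self.runnable
```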
@@ -307,22 +263,19 @@ class XliteWrapper:
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor,
-                                                     list[torch.Tensor]]:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         forward_context = get_forward_context()
         attn_metadata: Any = forward_context.attn_metadata
         if attn_metadata is None:
-            return self.runnable(input_ids, positions, intermediate_tensors,
-                                 inputs_embeds)
+            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)

         attn_metadata = next(iter(attn_metadata.values()), None)
-        if attn_metadata is None or not isinstance(attn_metadata,
-                                                   AscendMetadata):
-            return self.runnable(input_ids, positions, intermediate_tensors,
-                                 inputs_embeds)
+        if attn_metadata is None or not isinstance(attn_metadata, AscendMetadata):
+            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)

         with_prefill = attn_metadata.attn_state not in [
-            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+            AscendAttentionState.DecodeOnly,
+            AscendAttentionState.SpecDecoding,
         ]

         if not with_prefill or self.full_mode:
@@ -335,11 +288,7 @@ class XliteWrapper:
             num_prefills = attn_metadata.num_prefills
             batch = num_prefills + num_decodes
             seq_lens = attn_metadata.seq_lens[:batch]
-            seq_tensor = torch.cat([
-                torch.tensor([0]),
-                torch.tensor(attn_metadata.actual_seq_lengths_q)
-            ],
-                                   dim=0)
+            seq_tensor = torch.cat([torch.tensor([0]), torch.tensor(attn_metadata.actual_seq_lengths_q)], dim=0)
             query_lens = seq_tensor[1:] - seq_tensor[:-1]
             query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens
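The arithmetic above relies on `actual_seq_lengths_q` being the cumulative (prefix-sum) query lengths of the batch: prepending a zero and differencing recovers each request's new-token count, and subtracting that from `seq_lens` gives the number of tokens already cached. A small worked example under that assumption:

```python
import torch

# Three requests with 2, 1 and 4 new tokens -> cumulative query lengths [2, 3, 7].
actual_seq_lengths_q = [2, 3, 7]
seq_lens = torch.tensor([10, 5, 4])  # total tokens per request (cache + new)

seq_tensor = torch.cat([torch.tensor([0]), torch.tensor(actual_seq_lengths_q)], dim=0)
query_lens = seq_tensor[1:] - seq_tensor[:-1]  # tensor([2, 1, 4])
cached_lens = seq_lens - query_lens            # tensor([8, 4, 0])
print(query_lens.tolist(), cached_lens.tolist())
```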
@@ -347,23 +296,19 @@ class XliteWrapper:
             xlite_attn_metadata = ModelAttnMeta()
             xlite_attn_metadata.lens = query_lens.tolist()
             xlite_attn_metadata.cached_lens = cached_lens.tolist()
-            xlite_attn_metadata.is_prefills = [False] * num_decodes + [
-                True
-            ] * num_prefills
-            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu(
-            ).tolist()
+            xlite_attn_metadata.is_prefills = [False] * num_decodes + [True] * num_prefills
+            xlite_attn_metadata.block_tables = attn_metadata.block_tables.cpu().tolist()

-            h = self.hidden_states[:attn_metadata.num_actual_tokens]
+            h = self.hidden_states[: attn_metadata.num_actual_tokens]
             stream = torch.npu.current_stream().npu_stream
             if inputs_embeds is None:
-                self.xlite_model.forward(self.xlite_rt, input_ids,
-                                         xlite_attn_metadata, self.kv_caches,
-                                         self.freq_cis, h, stream)
+                self.xlite_model.forward(
+                    self.xlite_rt, input_ids, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
+                )
             else:
                 self.xlite_model.forward_with_inputs_embeds(
-                    self.xlite_rt, inputs_embeds, xlite_attn_metadata,
-                    self.kv_caches, self.freq_cis, h, stream)
+                    self.xlite_rt, inputs_embeds, xlite_attn_metadata, self.kv_caches, self.freq_cis, h, stream
+                )
             return h
         else:
-            return self.runnable(input_ids, positions, intermediate_tensors,
-                                 inputs_embeds)
+            return self.runnable(input_ids, positions, intermediate_tensors, inputs_embeds)
@@ -22,13 +22,13 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


 class XliteModelRunner(NPUModelRunner):

     def get_model(self) -> nn.Module:
         return self.model.unwrap()

     def load_model(self) -> None:
         super().load_model()
         from vllm_ascend.xlite.xlite import XliteWrapper

         self.model = XliteWrapper(self.model, self.vllm_config)

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: