v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

View File

View File

@@ -0,0 +1,364 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from collections.abc import Iterable
from typing import Any, List, Optional, Union
import torch
import torch.nn.functional as F
import vllm
import vllm.envs as envs
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
def all_gather_and_maybe_unpad(
hidden_states: torch.Tensor,
pad_size: int,
) -> torch.Tensor:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
if pad_size > 0:
return hidden_states[:-pad_size, :]
return hidden_states
def maybe_pad_and_reduce_scatter(
hidden_states: torch.Tensor,
pad_size: int,
) -> torch.Tensor:
if pad_size > 0:
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
return hidden_states
class CustomQwen2Attention(Qwen2Attention):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=prefix,
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
q, k = self.rotary_emb(positions,
q,
k,
is_prefill=False,
is_qwen_torchair=True)
forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = q.shape
output = torch.empty(output_shape,
dtype=q.dtype,
device=q.device)
forward_kwargs['output'] = output
attn_output = self.attn.impl.forward(self.attn,
q,
k,
v,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
trace_flag=False,
**forward_kwargs)
output, _ = self.o_proj(attn_output)
return output
else:
if type(self.rotary_emb) is RotaryEmbedding:
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class CustomQwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = CustomQwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class CustomQwen2Model(Qwen2Model):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=decoder_layer_type)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kv_cache = kv_caches[i - self.start_layer] \
if kv_caches is not None else None
hidden_states, residual = layer(positions,
hidden_states,
residual,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# add `CustomQwen2Model` to init self.model
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = CustomQwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM

View File

@@ -0,0 +1,537 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Any, List, Optional, Union
import torch
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.top_k = config.num_experts_per_tok
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)
return hidden_states
class CustomQwen3MoeAttention(Qwen3MoeAttention):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@staticmethod
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
head_dim: int, q_norm, k_norm):
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q_by_head = q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
k_by_head = k_norm(k_by_head)
k = k_by_head.view(k.shape)
return q, k, v
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
self.head_dim, self.q_norm, self.k_norm)
if (self.torchair_graph_enabled and attn_metadata is not None and
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
q, k = self.rotary_emb(positions,
q,
k,
is_prefill=False,
is_qwen_torchair=True)
forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = q.shape
output = torch.empty(output_shape,
dtype=q.dtype,
device=q.device)
forward_kwargs['output'] = output
attn_output = self.attn.impl.forward(self.attn,
q,
k,
v,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
trace_flag=False,
**forward_kwargs)
output, _ = self.o_proj(attn_output)
return output
else:
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: Optional[VllmConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = CustomQwen3MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
if not self.use_aclgraph:
# FIXME: custom sparse moe block doesn't work with aclgraph.
self.mlp = CustomSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if not self.use_aclgraph:
hidden_states = self.mlp(
hidden_states, _metadata_for_padding=_metadata_for_padding)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.num_redundant_experts = parallel_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CustomQwen3MoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
SupportsPP.__init__(self)
SupportsLoRA.__init__(self)
MixtureOfExperts.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states

View File

@@ -0,0 +1,218 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
TorchairDeepseekV2DecoderLayer
class TorchairDeepSeekShareHead(SharedHead):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
nn.Module.__init__(self)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = TorchairDeepSeekShareHead(config=config,
quant_config=quant_config,
prefix=maybe_prefix(
prefix,
"shared_head"))
self.mtp_block = TorchairDeepseekV2DecoderLayer(
config, prefix, model_config, cache_config, quant_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
torch.zeros_like(inputs_embeds),
inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
TorchairDeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
# Note: torch._dynamo.exc.Unsupported: builtin: str
self.layers_list = [
self.layers[str(idx)]
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
]
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
step_kv_cache = kv_caches[
current_step_idx] if kv_caches is not None else None
return self.layers_list[current_step_idx](
input_ids,
positions,
step_kv_cache,
attn_metadata,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers_list[current_step_idx]
logits = self.logits_processor(mtp_layer.shared_head.head,
mtp_layer.shared_head(hidden_states),
sampling_metadata)
return logits
class TorchairDeepSeekMTP(DeepSeekMTP):
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config.model_config.hf_config
self.model = TorchairDeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
TorchairDeepseekV2ForCausalLM
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
pass

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,372 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_custom_op, is_310p
def custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
)
def rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config(
).torchair_graph_config.enabled and not is_qwen_torchair:
return self.forward_native(
positions,
query,
key,
offsets,
)
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2, 2).transpose(3,
2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
low = torch.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = torch.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
def yarn_linear_ramp_mask(min_value, max_value, dim):
# Note: The if conditional branch is not used here
# to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) -
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if len(q.shape) == 3:
q = q[:, :, None, :]
if len(k.shape) == 2:
k = k[:, None, None, :]
elif len(k.shape) == 3:
k = k[:, :, None, :]
b, h_q, s, d = q.shape
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
b, h_k, s, d = k.shape
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.view(b, h_q, d)
k_embed = k_embed.view(b, h_k, d)
return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
freq_inter = 1.0 / (self.scaling_factor * self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
device=device, dtype=torch.float32)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype)
cache = torch.cat([freqs.cos() * self.mscale,
freqs.sin() * self.mscale],
dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def __set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
(1 / self.rotary_dim)))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
self.embed = F.embedding
_original_re_init = RotaryEmbedding.__init__
def qwen_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
if get_ascend_config().torchair_graph_config.enabled:
__set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)
def rope_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(max_seq_len,
self.max_position_embeddings):
__set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)
# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
query = query.unsqueeze(1)
key = key.unsqueeze(1)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore
def deepseek_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
_set_cos_sin_cache(self,
max_position_embeddings,
dtype=dtype,
device="npu")

View File

@@ -0,0 +1,29 @@
from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
TorchairAscendW4A8DynamicFusedMoEMethod,
TorchairAscendW4A8DynamicLinearMethod)
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
TorchairAscendW8A8DynamicFusedMoEMethod,
TorchairAscendW8A8DynamicLinearMethod)
class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer):
@staticmethod
def build_linear_method():
return TorchairAscendW8A8DynamicLinearMethod()
@staticmethod
def build_moe_method():
return TorchairAscendW8A8DynamicFusedMoEMethod()
class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer):
@staticmethod
def build_linear_method():
return TorchairAscendW4A8DynamicLinearMethod()
@staticmethod
def build_moe_method():
return TorchairAscendW4A8DynamicFusedMoEMethod()

View File

@@ -0,0 +1,439 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import numpy as np
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
class TorchairAscendW4A8DynamicLinearMethod:
"""Linear method for Ascend W4A8_DYNAMIC
"""
def __init__(self):
self.transpose_weight = True
try:
self.group_size = get_current_vllm_config(
).quant_config.quant_description.get("group_size", 256)
except AttributeError:
self.group_size = 256
@staticmethod
def get_weight(input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
def get_pergroup_param(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
return params_dict
@staticmethod
def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
per_group_scale: torch.Tensor):
k, n = weight.shape
group_num, n = per_group_scale.shape
weight_high = weight.to(torch.float32).reshape(
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
return antiquant_scale.npu(), bias
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul(
x,
layer.weight,
antiquant_scale=layer.weight_scale_second.to(x.dtype),
antiquant_group_size=self.group_size,
)
def process_weights_after_loading(self, layer: torch.nn.Module):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data,
layer.weight_scale.data,
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
)
param = torch.nn.Parameter(scale_bias, requires_grad=False)
layer.register_parameter("weight_scale_bias", param)
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
class TorchairAscendW4A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W4A8_DYNAMIC.
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8
self.new_quant_version = quant_version == "1.0.0"
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
if self.new_quant_version and self.tp_size > 16:
raise ValueError(
"The current weight does not support moe part tp>16.")
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
if self.new_quant_version:
w13_output_size = intermediate_size_per_partition
w2_output_size = hidden_sizes // 2
else:
w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts,
w13_output_size,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_scale_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_offset_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=params_dtype)
if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_scale_bias"] = torch.empty(num_experts,
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = torchair_select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
if fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second,
w2_scale=layer.w2_weight_scale_second,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into layers module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return torchair_fused_experts_with_all2all(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second,
w2_scale=layer.w2_weight_scale_second,
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
group_num, k, n = weight.shape
# the weight of the new version is reduced by half by pack n, so it needs to be restored
if self.new_quant_version:
n = n * 2
per_group_scale = per_group_scale.reshape(group_num, -1, n)
group_num, quantgroup_num, n = per_group_scale.shape
bias = None
if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
layer.register_parameter("w2_scale_bias", w2_scale_bias)
def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version:
group_num, k, n = weight.shape
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
packed_n = n // 4
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
packed_weight = torch.from_numpy(
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
return packed_weight.reshape(group_num, k, packed_n).npu()
else:
return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None,
torch.quint4x2, -1, False)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data,
layer.w13_weight_scale_second.data)
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data,
layer.w2_weight_scale_second.data)
self.update_bias(layer, w13_bias, w2_bias)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,452 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d)
class AscendAttentionTorchairBackend(AscendAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ASCEND_TORCHAIR"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
return AscendAttentionTorchairBackendImpl
@staticmethod
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
return AscendTorchairMetadata
@staticmethod
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
return AscendAttentionTorchairMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@dataclass
class AscendDecodeMetadata:
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendTorchairMetadata(AscendMetadata):
decode: Optional[AscendDecodeMetadata] = None
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(vllm_config, device)
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
self.vllm_config.cache_config.block_size)
self.max_blocks = (self.model_config.max_model_len +
self.vllm_config.cache_config.block_size -
1) // self.vllm_config.cache_config.block_size
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
max_blocks = self.max_blocks
graph_block_tables = torch.zeros((num_seqs, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
num_blocks = block_tables.size(1)
if num_blocks <= max_blocks:
graph_block_tables[:num_seqs, :
num_blocks] = block_tables[:num_seqs, :
num_blocks]
else:
graph_block_tables[:num_seqs, :
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables[:, :max_blocks]
def build_torchair_graph_dummy(
self, common_attn_metadata: TorchairCommonAttentionMetadata
) -> AscendTorchairMetadata:
device = self.device
num_reqs = common_attn_metadata.num_reqs
block_table = torch.zeros((num_reqs, self.max_blocks),
dtype=torch.int32,
device=device)
block_table = self._get_graph_runner_block_tables(
num_reqs, block_table)
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
input_positions = torch.zeros(num_reqs,
dtype=torch.int32,
device=device).long()
slot_mapping = torch.full((num_reqs, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
query_start_loc = torch.full((num_reqs, ),
-1,
dtype=torch.int32,
device=device)
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1)
attn_metadata = AscendTorchairMetadata(
num_actual_tokens=common_attn_metadata.num_actual_tokens,
block_tables=block_table,
query_lens=0,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly,
decode=decode_metadata)
return attn_metadata
def build(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
block_table = common_attn_metadata.block_table_tensor
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_table[:num_reqs])
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
decode_metadata = None
graph_pad_size = common_attn_metadata.graph_pad_size
use_torchair_graph = graph_pad_size > -1
if common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
]:
max_seq_lens = seq_lens.max().item()
num_seqs = len(seq_lens)
if use_torchair_graph and common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
]:
num_reqs_pad_size = 0
num_token_pad_size = 0
if graph_pad_size != 0:
pad_value = 0
num_token_pad_size = graph_pad_size - num_actual_tokens
num_reqs_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value
] * num_reqs_pad_size
seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, padding])
block_table_padding = torch.zeros(
(num_reqs_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
block_table = self._get_graph_runner_block_tables(
num_seqs + num_reqs_pad_size, block_table)
padding_0 = torch.zeros(num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat([input_positions, padding_0])
decode_metadata = AscendDecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens,
attn_mask=None)
attn_metadata = AscendTorchairMetadata(
decode=decode_metadata,
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=common_attn_metadata.max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
return attn_metadata
class AscendAttentionTorchairBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.hidden_size = self.num_heads * self.head_size
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendTorchairMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = False,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads, head_size]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0
and kv_cache[0].numel() > 0
and kv_cache[0].dtype == torch.int8)
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)
return output.view(num_tokens, self.hidden_size)
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
output = output.view(-1, self.num_heads, self.head_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"AscendAttentionTorchairBackendImpl")
if kv_cache is not None and kv_cache[0].numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
block_size = self.scale_tensor + key_cache.shape[1]
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size
slots_indices = slots_indices % block_size
indices = torch.cat((block_indices, slots_indices), dim=1)
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if is_310p():
# align q k v output tensors
query = aligned_16(query)
key = aligned_16(key)
value = aligned_16(value)
output = aligned_16(output)
# do reformat in case of broadcasted tensors
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
output = output[:num_tokens, :, :]
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
decode_meta = attn_metadata.decode
assert decode_meta is not None
seq_lens = decode_meta.seq_lens_list
block_table = decode_meta.block_table
block_size = key_cache.shape[1]
query = query.view(num_tokens, 1,
self.num_heads * self.head_size).contiguous()
output = torch_npu.npu_incre_flash_attention(
query,
key_cache,
value_cache,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
actual_seq_lengths=seq_lens,
scale_value=self.scale,
block_table=block_table,
input_layout='BSH',
block_size=block_size)
else:
raise NotImplementedError(
"Torchair graph mode with non-MLA attention backend is still experimental."
"v1 scheduler(chunked prefill) is not supported at this moment. Please"
"setting 'ascend_scheduler_config':{'enabled':true} in additional_config"
"to use ascend scheduler.")
return output.view(num_tokens, self.hidden_size)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,446 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# isort: skip_file
import types
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
check_torchair_cache_exist, converting_weight_acl_format,
register_torchair_model, torchair_ops_patch,
torchair_quant_method_register, write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_310p)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class NPUTorchairModelRunner(NPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
super().__init__(vllm_config, device)
ascend_config = get_ascend_config()
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.check_torchair_graph_batch_sizes()
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
self._check_batch_sizes_consistency()
register_torchair_model()
torchair_ops_patch()
torchair_quant_method_register()
def _sync_metadata_across_dp(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.dp_size == 1:
if not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
dtype=torch.int32,
device="npu")
num_tokens_across_dp[self.dp_rank] = num_tokens
num_tokens_across_dp[-2] = int(with_prefill)
num_tokens_across_dp[-1] = int(not enable_dbo)
dist.all_reduce(num_tokens_across_dp,
group=get_dp_group().device_group)
with_prefill = bool(num_tokens_across_dp[-2])
enable_dbo = not bool(num_tokens_across_dp[-1])
num_tokens_across_dp = num_tokens_across_dp[:-2]
if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
max_num_token)
num_tokens_across_dp = torch.full((self.dp_size, ),
maybe_padded_num_tokens,
dtype=torch.int32,
device="npu")
else:
maybe_padded_num_tokens = num_tokens
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if not with_prefill:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
num_actual_tokens=1,
actual_seq_lengths_q=self.actual_seq_lengths_q,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
common_attn_metadata)
else:
attn_metadata = super()._build_attention_metadata(
with_prefill, num_reqs, skip_attn)
return attn_metadata
def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):
if not with_prefill:
# Only mark static while compiling
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
if self.speculative_config:
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
for kv in self.kv_caches:
assert isinstance(kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
if is_310p():
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
model_kwargs = {}
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
if is_310p():
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = super()._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
return hidden_states
def _convert_torch_format(self, kv_cache):
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
return kv_cache
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
def _capture_model(self):
"""Override from NPUModelRunner to use torchair graph capture."""
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
if self.use_cached_npu_graph and not check_torchair_cache_exist():
# If caching is enabled but does not exist (either
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
# different), we will compile the model twice. The first time is
# used to generate the cache, and the second time is used to load the
# cache to skip the overhead caused by Dynamo guard mechanism.
logger.info(
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
" torchair cache, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
# Note: We reset dynamo and reload the compiled torchair cached computation graph below
# that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
# runtime errors caused by configuration mismatches in graph mode.
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
def _use_aclgraph(self) -> bool:
return False
def _check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
local = torch.tensor(self.torchair_graph_batch_sizes,
device="cpu",
dtype=torch.int32)
gathered_graph_batch_size = local.clone()
dist.all_reduce(gathered_graph_batch_size,
group=get_dp_group().cpu_group)
expected = local * self.dp_size
if not torch.equal(gathered_graph_batch_size, expected):
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
as_tuple=False).flatten().tolist()
raise AssertionError(
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
)
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
if not with_prefill:
self.graph_pad_size = graph_pad_size
else:
super()._update_graph_pad_size(with_prefill, graph_pad_size)
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
"""Override from NPUModelRunner to update input_ids and positions"""
input_ids, positions = super()._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)
if not with_prefill:
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
return input_ids, positions
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
padded_num_tokens_across_dp,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
model_kwargs = {
"kv_caches": self.kv_caches,
"attn_metadata": attn_metadata
}
if not with_prefill:
if is_310p():
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
assert self.model is not None
if is_310p():
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return hidden_states
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
if is_310p():
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig()
if get_ascend_config().torchair_graph_config.mode:
config.mode = get_ascend_config().torchair_graph_config.mode
config.experimental_config.frozen_parameter = True
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size()
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
start_graph_batch_size = max(start_graph_batch_size, tp_size)
while (start_graph_batch_size <= self.max_num_reqs):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size:
# we treat batch_size as num of requests
return padded_batch_size
raise ValueError(
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
)
def check_torchair_graph_batch_sizes(self):
# return graph_batch_sizes according to the max number of tokens
# first pad according to the number of requests
if len(self.torchair_graph_batch_sizes) == 0:
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
else:
self.torchair_graph_batch_sizes = sorted(
self.torchair_graph_batch_sizes)
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
self.torchair_graph_batch_sizes.pop()
if len(self.torchair_graph_batch_sizes) == 0:
logger.warning(
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
)
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
# padded max number tokens = max_num_req * decode_token_per_req
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
tp_size = self.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel:
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
def _build_drafter_prepare_inputs_torchair_param(self):
return True
def get_dp_padding(self, num_tokens):
"""Override from NPUModelRunner to get dp padding"""
return 0, None

View File

@@ -0,0 +1,63 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file)
from vllm_ascend.worker.worker_v1 import NPUWorker
class NPUTorchairWorker(NPUWorker):
"""Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class."""
def determine_available_memory(self) -> int:
"""Override determine_available_memory to use cached torchair kv_cache_bytes."""
available_kv_cache_memory = super().determine_available_memory()
if get_ascend_config(
).torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist(
):
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory
def init_device(self):
"""Override init_device to init torchair model runner"""
device = self._init_device()
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)

View File

@@ -0,0 +1,205 @@
import fcntl
import os
import shutil
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
import torch
import torch_npu
try:
# Recent release of torchair has moved these ops to `.scope`.
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
TORCHAIR_CACHE_DIR = os.path.join(
os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)
@dataclass
class TorchairCommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both GPU and CPU versions.
"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
"""Total number of tokens in batch"""
decode_token_per_req: int
actual_seq_lengths_q: list[int]
attn_mask: torch.Tensor = None
spec_attn_mask: torch.Tensor = None
graph_pad_size: int = -1
@contextmanager
def _file_lock(file_descriptor, lock_type):
fcntl.flock(file_descriptor, lock_type)
try:
yield
finally:
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
def _get_torchair_current_work_dir(file_name=None):
if file_name is None:
return TORCHAIR_CACHE_DIR
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
def check_torchair_cache_exist():
res = False
torch_air_abs_path = _get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
file_list = os.listdir(torch_air_abs_path)
if len(file_list) != 0:
res = True
return res
def check_kv_cache_bytes_cache_exist():
res = False
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
if os.path.exists(kv_cache_bytes_cache_abs_path):
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
if len(file_list) != 0:
res = True
return res
def read_kv_cache_bytes_from_file(rank) -> int:
kv_cache_bytes = -1
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
with _file_lock(f, fcntl.LOCK_SH):
kv_cache_bytes = int(f.readline())
return kv_cache_bytes
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
with _file_lock(f, fcntl.LOCK_EX):
f.write(f"{kv_cache_bytes}")
def delete_torchair_cache_file():
torch_air_abs_path = _get_torchair_current_work_dir()
try:
shutil.rmtree(torch_air_abs_path)
except FileNotFoundError:
pass
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
def npu_wait_tensor(self: torch.Tensor,
dependency: torch.Tensor,
*,
enabled: bool = True):
return _npu_wait_tensor(self, dependency) if enabled else self
def converting_weight_acl_format(model, format):
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
# conversion when using torchair graph mode on 300I Duo platform.
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
for module in model.modules():
if isinstance(module, FusedMoE):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
module.w2_weight.data = torch_npu.npu_format_cast(
module.w2_weight.data, format)
def register_torchair_model():
from vllm import ModelRegistry
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
)
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
def torchair_quant_method_register():
from vllm_ascend.quantization.quantizer import \
SUPPORT_ASCEND_QUANTIZER_TYPE
from vllm_ascend.torchair.quantization.torchair_quantizer import (
TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer)
SUPPORT_ASCEND_QUANTIZER_TYPE[
"W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
SUPPORT_ASCEND_QUANTIZER_TYPE[
"W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer
def torchair_ops_patch():
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
deepseek_rope_init_func, native_rope_deepseek_forward,
qwen_rope_init_func, rope_forward)
AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign]
AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign]
AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign]
AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign]