diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index 0dfbf2b..1eb9d2f 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -22,8 +22,6 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`. import os from typing import Dict -import pytest - from tests.e2e.conftest import VllmRunner os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" @@ -155,7 +153,6 @@ def _pangu_torchair_test_fixture( print(f"Generated text: {vllm_output[i][1]!r}") -@pytest.mark.skip("pangu doesn't work, fix me") def test_e2e_pangu_with_torchair(): additional_config = { "torchair_graph_config": { diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py index b68bc31..fb526b5 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -65,7 +65,7 @@ class TestTorchairUtils(TestBase): mock_model_registry.return_value = mock_registry utils.register_torchair_model() - self.assertEqual(mock_model_registry.register_model.call_count, 5) + self.assertEqual(mock_model_registry.register_model.call_count, 6) call_args_list = mock_model_registry.register_model.call_args_list expected_registrations = [ @@ -81,7 +81,11 @@ class TestTorchairUtils(TestBase): ("Qwen2ForCausalLM", "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"), ("Qwen3MoeForCausalLM", - "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM") + "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM" + ), + ("PanguProMoEForCausalLM", + "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" + ) ] for i, (expected_name, diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index c0b2a14..3e2148c 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -57,7 +57,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p _ROUTER_SCALE = None @@ -612,9 +611,6 @@ class PanguProMoEAttention(nn.Module): prefix=f"{prefix}.attn", ) - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - def forward( self, positions: torch.Tensor, @@ -625,18 +621,7 @@ class PanguProMoEAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - if self.torchair_graph_enabled: - forward_kwargs = {'trace_flag': False} - output_shape = q.shape - attn_output = torch.empty(output_shape, - dtype=q.dtype, - device=q.device) - forward_kwargs['output'] = attn_output - attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache, - attn_metadata, - **forward_kwargs) - else: - attn_output = self.attn(q, k, v) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 2c11e6d..607991c 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -170,15 +170,6 @@ def fused_experts_moge( local_num_experts = global_num_experts // ep_size local_num_group = top_k // ep_size - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -407,6 +398,7 @@ class AscendFusedMoE(FusedMoE): prefix="", custom_routing_function=None, scoring_func="softmax", + routed_scaling_fator: float = 1.0, e_score_correction_bias=None, apply_router_weight_on_input=False, activation="silu", @@ -414,31 +406,59 @@ class AscendFusedMoE(FusedMoE): num_redundant_experts=0, has_bias=False, ): - super().__init__( - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype, - reduce_results, - renormalize, - use_grouped_topk, - num_expert_group, - topk_group, - quant_config, - tp_size, - ep_size, - dp_size, - prefix, - custom_routing_function, - scoring_func, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - num_redundant_experts, - has_bias, - ) + if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): + super().__init__( + num_experts, + top_k, + hidden_size, + intermediate_size, + params_dtype, + reduce_results, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + quant_config, + tp_size, + ep_size, + dp_size, + prefix, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + num_redundant_experts, + has_bias, + ) + else: + super().__init__( + num_experts, + top_k, + hidden_size, + intermediate_size, + params_dtype, + reduce_results, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + quant_config, + tp_size, + ep_size, + dp_size, + prefix, + custom_routing_function, + scoring_func, + routed_scaling_fator, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + num_redundant_experts, + has_bias, + ) setup_token_dispatchers(self.moe_config.ep_size, top_k=self.top_k, diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py new file mode 100644 index 0000000..eb05760 --- /dev/null +++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py @@ -0,0 +1,1119 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch_npu +from torch import nn +from torch.nn import Parameter +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (divide, get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group, get_world_group) +from vllm.forward_context import get_forward_context +from vllm.logger import logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p + +_ROUTER_SCALE = None + + +def use_h2p(): + # only use H2P when dp_size > 1. + if get_dp_group().world_size > 1: + return True + return False + + +# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear. +# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp). +class CustomMergedColumnParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the last dimension. + output_size = sum(output_sizes) + self.output_sizes = output_sizes + self.tp_size = get_tp_group().world_size + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + loaded_shard_id: int): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + assert loaded_shard_id < len(self.output_sizes) + + tp_rank = get_tp_group().rank_in_group + tp_size = get_tp_group().world_size + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + +# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear. +# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp) +# and detach communication to enable customized communication algorithms(e.g., H2P). +class CustomRowParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + group=None, + ): + # Divide the weight matrix along the first dimension. + self.group = group if group is not None else get_tp_group() + self.tp_rank = self.group.rank_in_group + self.tp_size = self.group.world_size + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = self.group.rank_in_group + input_dim = getattr(param, "input_dim", None) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + is_sharded_weight = is_sharded_weight + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + input_parallel = input_ + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output = self.quant_method.apply(self, input_parallel, bias=bias_) + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class PanguProMoEMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + if not use_h2p(): + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + else: + self.gate_up_proj = CustomMergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = CustomRowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def topk_wrapper(num_voted_experts): + + def pangu_group8_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool = False, + num_expert_group: int = 0, + topk_group: int = 0, + global_num_experts: int = 0, + ): + scores = F.softmax(gating_output, dim=1) + num_tokens = scores.shape[0] + router_scale = _ROUTER_SCALE.squeeze( # type: ignore + ) + # TODO: support disable expert parallel + ep_size = get_ep_group().world_size + local_num_experts = global_num_experts // ep_size + local_num_group = topk // ep_size + experts_per_group = global_num_experts // topk + local_group_start = get_ep_group().rank_in_group * local_num_experts + local_group_end = (get_ep_group().rank_in_group + + 1) * local_num_experts + scores = F.softmax(gating_output, dim=1) + scores = scores[..., local_group_start:local_group_end] + + router_weights = router_scale[local_group_start:local_group_end] + + if num_voted_experts == 8: + # use original topk + topk_weights, topk_ids = torch.max(scores.view( + scores.shape[0], local_num_group, -1), + dim=-1) + bias = torch.arange(0, + local_num_experts, + experts_per_group, + device=scores.device, + dtype=torch.int32).unsqueeze(0) + topk_ids = topk_ids.to(torch.int32) + bias + + else: + group_expert_indices = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device).view( + 1, 1, -1) + group_expert_offset = (torch.arange( + local_num_group, dtype=torch.int32, device=scores.device) * + experts_per_group).unsqueeze(0) + expert_index_range = torch.arange(experts_per_group, + dtype=torch.int32, + device=scores.device) + + scores_grouped = scores.view(num_tokens, local_num_group, + experts_per_group) + best_expert_idx = torch.argmax(scores_grouped, + dim=2) # (num_tokens, num_groups) + vote_mask = (best_expert_idx.unsqueeze(-1).to( + torch.int32) == group_expert_indices) + + expert_vote_freq = vote_mask.sum(dim=0) + + sorted_indices = torch.argsort(expert_vote_freq, + dim=1, + descending=True).to(torch.int32) + topk_experts = sorted_indices[:, :num_voted_experts] + keep_mask = (( + topk_experts.unsqueeze(-1) == expert_index_range).any( + dim=1)).unsqueeze(0) + + masked_scores = torch.where(keep_mask, scores_grouped, 0) + + topk_weights, best_pos_in_group = masked_scores.max(dim=2) + best_pos_in_group = best_pos_in_group.to(torch.int32) + topk_ids = (best_pos_in_group + group_expert_offset).to( + torch.int32) + + flatten_topk_ids = topk_ids.view(-1) + router_weights = router_weights.index_select(0, flatten_topk_ids).view( + topk_ids.shape) + topk_weights *= router_weights + + return topk_weights, topk_ids + + return pangu_group8_topk + + +class PanguProMoESparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.num_experts_per_tok = config.num_experts_per_tok + self.router_scale = torch.nn.Parameter( + torch.ones((1, self.num_experts))) + + # on 300I Duo platform, we find that num_voted_experts set to 5 achieves + # good performance without sacrifice too much accuracy. for other platform, + # this is set to 8 to use original pangu grouped topk. + num_voted_experts = 5 if is_310p() else 8 + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + quant_config=quant_config, + custom_routing_function=topk_wrapper(num_voted_experts), + prefix=f"{prefix}.experts", + ) + self.use_ep = self.experts.use_ep + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = PanguProMoEMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None # type: ignore + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + global _ROUTER_SCALE + _ROUTER_SCALE = self.router_scale + + # TODO(angazenn): Does not support MC2 currently + get_forward_context().moe_comm_method_name = "allgathercommimpl" + + if not use_h2p(): + final_hidden_states = self.experts.forward_impl( + hidden_states=hidden_states, router_logits=router_logits) + else: + # TODO: when using h2p, we have to skip communication in vLLM + # native FusedMoE. here we need to design a better FusedMoE + # (maybe using AscendFusedMoE) to enable these different + # communication schema. + final_hidden_states = self.experts.quant_method.apply( + layer=self.experts, + x=hidden_states, + router_logits=router_logits, + top_k=self.experts.top_k, + renormalize=False, + use_grouped_topk=False, + global_num_experts=self.experts.global_num_experts, + expert_map=self.experts.expert_map, + custom_routing_function=self.experts.custom_routing_function, + apply_router_weight_on_input=self.experts. + apply_router_weight_on_input) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if not use_h2p(): + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class PanguProMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + if use_h2p(): + self.o_proj = CustomRowParallelLinear(self.total_num_heads * + self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + group=get_tp_group()) + else: + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + if self.torchair_graph_enabled: + forward_kwargs = {'trace_flag': False} + output_shape = q.shape + attn_output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = attn_output + attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache, + attn_metadata, + **forward_kwargs) + else: + attn_output = self.attn(q, k, v) + + output, _ = self.o_proj(attn_output) + return output + + +class PanguProMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + + self.self_attn = PanguProMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): + self.mlp = PanguProMoESparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = PanguProMoEMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + h2p_unpad_idx: Optional[torch.Tensor] = None, + h2p_pad_idx: Optional[torch.Tensor] = None, + is_start_layer: Optional[bool] = False, + ) -> torch.Tensor: + need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \ + and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0] + tp_size = get_tp_group().world_size + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + if use_h2p(): + if is_start_layer: + if need_h2p_pad: + residual = residual.index_select(dim=0, index=h2p_pad_idx) + residual = torch.tensor_split( + residual, tp_size)[get_tp_group().rank_in_group] + else: + if tp_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + if need_h2p_pad: + hidden_states = hidden_states.index_select( + dim=0, index=h2p_unpad_idx) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if use_h2p(): + if need_h2p_pad: + hidden_states = hidden_states.index_select(dim=0, + index=h2p_pad_idx) + if tp_size > 1: + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_tp_group().device_group) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if use_h2p(): + all_rank_group = get_world_group().device_group + output_size = (hidden_states.shape[0] * + get_world_group().world_size, + hidden_states.shape[1]) + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=hidden_states.dtype, + device=hidden_states.device) + # All-gather. + dist.all_gather_into_tensor(output_tensor, + hidden_states, + group=all_rank_group) + hidden_states = output_tensor + + hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata) + + if use_h2p(): + hidden_states = dist._functional_collectives.reduce_scatter_tensor( + hidden_states, + "sum", + scatter_dim=0, + group=get_world_group().device_group) + + return hidden_states, residual + + +@support_torch_compile +class PanguProMoEModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PanguProMoEDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + if use_h2p(): + # calculate necessary padding/unpadding idx before model forward. + + # the attn_metadata will be passed directly when use torchair. + # if attn_meatadata is not passed, we try to get it from forward_context. + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + + max_tokens_across_dp = get_forward_context().max_tokens_across_dp + + tp_size = get_tp_group().world_size + # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. + # we need pad it before if the shape can't be divided by group size. + # for h2p, we need pad it so that it can be divided by tp_size. + h2p_padded_len = ( + tp_size - (max_tokens_across_dp % tp_size) + ) % tp_size + max_tokens_across_dp - hidden_states.shape[0] + h2p_unpad_idx = torch.arange(hidden_states.shape[0], + device=hidden_states.device, + dtype=torch.int32) + h2p_pad_idx = torch.cat([ + h2p_unpad_idx, + torch.zeros(h2p_padded_len, + dtype=torch.int32, + device=hidden_states.device) + ]) + else: + h2p_unpad_idx = None + h2p_pad_idx = None + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, h2p_unpad_idx, h2p_pad_idx, + i == self.start_layer) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + if use_h2p(): + if get_tp_group().world_size > 1: + hidden_states = get_tp_group().all_gather(hidden_states, 0) + if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]: + hidden_states = hidden_states.index_select(dim=0, + index=h2p_unpad_idx) + return hidden_states + + +class PanguProMoEForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = PanguProMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) # from model + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + # ======================================================= + # BF: add this to load with less layers + if 'layers' in name: + layer_idx = int(name.split('layers.')[-1].split('.')[0]) + if layer_idx >= self.model.end_layer: + continue + + if "rotary_emb.inv_freq" in name: + continue + + if "module" in name: + continue + + if name.endswith('kv_cache_offset'): + continue + + if name.endswith("k_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "k_proj.kv_cache_scale", "attn.key_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if name.endswith("v_proj.kv_cache_scale"): + remapped_kv_scale_name = name.replace( + "v_proj.kv_cache_scale", "attn.value_antiquant_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + loaded_weight = torch.tensor_split(loaded_weight, + tp_size, + dim=0)[tp_rank] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + # breakpoint() + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + # breakpoint() + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + if is_310p() and "head" in name: + # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than + # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented + # by linear, we manually cast the format here. + param.data = torch_npu.npu_format_cast(param.data, + ACL_FORMAT_FRACTAL_NZ) + return loaded_params diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 563fda7..13d5879 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -173,6 +173,11 @@ def register_torchair_model(): "Qwen3MoeForCausalLM", "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM") + ModelRegistry.register_model( + "PanguProMoEForCausalLM", + "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" + ) + def torchair_quant_method_register(): from vllm_ascend.quantization.quantizer import \