From 8fa188111da3a8f752dc309330d9bd3ec18194e6 Mon Sep 17 00:00:00 2001
From: Angazenn <92204292+Angazenn@users.noreply.github.com>
Date: Sat, 28 Jun 2025 16:10:27 +0800
Subject: [PATCH] [PERF]support H2P communication optimization for PanguProMoe
 (#1463)

### What this PR does / why we need it?

In this PR, we support the H2P communication optimization when running
PanguProMoE with dp_size > 1. H2P uses `reduce_scatter` and `all_gather` in
place of `all_reduce` to improve performance:

original layer:
input_layernorm --> attn --> tp all_reduce --> post_attention_layernorm -->
dp all_gather --> moe/mlp --> dp reduce_scatter --> tp all_reduce

now:
input_layernorm --> tp all_gather --> attn --> tp reduce_scatter -->
post_attention_layernorm --> all_rank all_gather --> moe/mlp -->
all_rank reduce_scatter

Besides, because `reduce_scatter` requires a token count that is divisible by
the group size, we need to pad the sequences based on `max_tokens_across_dp`
(see the standalone sketch after the diff).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR has been tested with both offline and online inference using
PanguProMoE-72B.

---------

Signed-off-by: angazenn
Co-authored-by: angazenn
---
 vllm_ascend/models/pangu_moe.py | 430 +++++++++++++++++++++++++++++---
 1 file changed, 399 insertions(+), 31 deletions(-)

diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py
index ff9bb4a..131e1e0 100644
--- a/vllm_ascend/models/pangu_moe.py
+++ b/vllm_ascend/models/pangu_moe.py
@@ -18,20 +18,26 @@
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
+from torch.nn import Parameter
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
+                                             get_world_group)
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -47,6 +53,7 @@ from vllm.model_executor.models.utils import (
     extract_layer_index, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -56,6 +63,225 @@ logger = init_logger(__name__)

 _ROUTER_SCALE = None


+def use_h2p():
+    # Only use H2P when dp_size > 1.
+    return get_dp_group().world_size > 1
+
+
+# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
+# It is used to customize the parallelism of certain linears (e.g., shared
+# experts with all-rank tp).
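+# The key difference from the upstream class is that the weight is partitioned
+# across the world group (all ranks) rather than the TP group, and forward()
+# performs no collective communication: gathering and scattering of the
+# activations are left entirely to the caller.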
+class CustomMergedColumnParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        # Divide the weight matrix along the last dimension.
+        output_size = sum(output_sizes)
+        self.output_sizes = output_sizes
+        self.tp_size = get_world_group().world_size
+        self.input_size_per_partition = input_size
+        self.output_size_per_partition = divide(output_size, self.tp_size)
+        self.output_partition_sizes = [self.output_size_per_partition]
+        # If QKV or MergedColumn, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = [
+                divide(output_size, self.tp_size)
+                for output_size in self.output_sizes
+            ]
+
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
+
+        self.gather_output = gather_output
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader)
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
+                      loaded_shard_id: int):
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+
+        assert loaded_shard_id < len(self.output_sizes)
+
+        tp_rank = get_world_group().rank_in_group
+        tp_size = get_world_group().world_size
+        if output_dim is not None:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            start_idx = tp_rank * shard_size
+            if not is_sharded_weight:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_, bias)
+        output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
+# It is used to customize the parallelism of certain linears (e.g., shared
+# experts with all-rank tp) and to detach communication so that customized
+# communication algorithms (e.g., H2P) can be applied.
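+# By default (group=None) the weight is sharded across the world group; the
+# attention output projection instead passes in the TP group. Note that
+# forward() returns partial sums without any reduction: the matching
+# reduce_scatter is issued by the decoder layer.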
+class CustomRowParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        group=None,
+    ):
+        # Divide the weight matrix along the first dimension.
+        self.group = group if group is not None else get_world_group()
+        self.tp_rank = self.group.rank_in_group
+        self.tp_size = self.group.world_size
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader)
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reducing the results, adding bias to "
+                             "the results can lead to incorrect results.")
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = self.group.rank_in_group
+        input_dim = getattr(param, "input_dim", None)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+
+        param_data = param.data
+        if input_dim is not None and not is_sharded_weight:
+            shard_size = param_data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
+                                                 shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        input_parallel = input_
+
+        # Matrix multiply.
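+        # NOTE: unlike RowParallelLinear, no all_reduce is performed here even
+        # when reduce_results is set; the H2P decoder layer combines the
+        # partial sums exactly once via an explicit reduce_scatter.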
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in the TP>1 case).
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output = self.quant_method.apply(self, input_parallel, bias=bias_)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class PanguProMoEMLP(nn.Module):

     def __init__(
@@ -68,21 +294,39 @@ class PanguProMoEMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            reduce_results=reduce_results,
-            prefix=f"{prefix}.down_proj",
-        )
+        if not use_h2p():
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size,
+                [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                reduce_results=reduce_results,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.gate_up_proj = CustomMergedColumnParallelLinear(
+                hidden_size,
+                [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
+            self.down_proj = CustomRowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                reduce_results=reduce_results,
+                prefix=f"{prefix}.down_proj",
+            )
+
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -207,11 +451,30 @@ class PanguProMoESparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         global _ROUTER_SCALE
         _ROUTER_SCALE = self.router_scale
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        if not use_h2p():
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
+        else:
+            # TODO: when using H2P, we have to skip the communication in the
+            # native vLLM FusedMoE. We need to design a better FusedMoE here
+            # (maybe using AscendFusedMoE) to enable these different
+            # communication schemas.
+            final_hidden_states = self.experts.quant_method(
+                layer=self.experts,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.experts.top_k,
+                renormalize=False,
+                use_grouped_topk=False,
+                global_num_experts=self.experts.global_num_experts,
+                expert_map=self.experts.expert_map,
+                custom_routing_function=self.experts.custom_routing_function,
+                apply_router_weight_on_input=self.experts.
+                apply_router_weight_on_input)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if not use_h2p():
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

@@ -265,13 +528,22 @@ class PanguProMoEAttention(nn.Module):
             prefix=f"{prefix}.qkv_proj",
         )

-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            hidden_size,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
+        if use_h2p():
+            self.o_proj = CustomRowParallelLinear(self.total_num_heads *
+                                                  self.head_dim,
+                                                  hidden_size,
+                                                  bias=True,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.o_proj",
+                                                  group=get_tp_group())
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )

         self.rotary_emb = get_rope(
             self.head_dim,
@@ -337,8 +609,7 @@
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        if (layer_idx
-                not in mlp_only_layers) and (config.num_experts > 0):  ### ???
+        if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
             self.mlp = PanguProMoESparseMoeBlock(
                 config=config,
                 quant_config=quant_config,
@@ -364,7 +635,14 @@
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        h2p_unpad_idx: Optional[torch.Tensor] = None,
+        h2p_pad_idx: Optional[torch.Tensor] = None,
+        is_start_layer: bool = False,
     ) -> torch.Tensor:
+        need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
+            and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
+        tp_size = get_tp_group().world_size
+
         # Self Attention
         if residual is None:
             residual = hidden_states
@@ -372,16 +650,64 @@
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+        if use_h2p():
+            if is_start_layer:
+                if need_h2p_pad:
+                    residual = residual.index_select(dim=0,
+                                                     index=h2p_pad_idx)
+                residual = torch.tensor_split(
+                    residual, tp_size)[get_tp_group().rank_in_group]
+            else:
+                if tp_size > 1:
+                    hidden_states = get_tp_group().all_gather(
+                        hidden_states, 0)
+                if need_h2p_pad:
+                    hidden_states = hidden_states.index_select(
+                        dim=0, index=h2p_unpad_idx)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
            attn_metadata=attn_metadata,
         )
+
+        if use_h2p():
+            if need_h2p_pad:
+                hidden_states = hidden_states.index_select(dim=0,
+                                                           index=h2p_pad_idx)
+            if tp_size > 1:
+                hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_tp_group().device_group)
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if use_h2p():
+            all_rank_group = get_world_group().device_group
+            output_size = (hidden_states.shape[0] *
+                           get_world_group().world_size,
+                           hidden_states.shape[1])
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=hidden_states.dtype,
+                                        device=hidden_states.device)
+            # All-gather.
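+            # The padded length is identical on every dp rank (it is derived
+            # from max_tokens_across_dp), so each rank contributes a shard of
+            # equal size and a plain all_gather over the world group
+            # reassembles the full padded batch for the MoE/MLP.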
+            dist.all_gather_into_tensor(output_tensor,
+                                        hidden_states,
+                                        group=all_rank_group)
+            hidden_states = output_tensor
+
+        hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
+
+        if use_h2p():
+            hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                hidden_states,
+                "sum",
+                scatter_dim=0,
+                group=get_world_group().device_group)

         return hidden_states, residual

@@ -440,19 +766,61 @@ class PanguProMoEModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
+
+        if use_h2p():
+            # Calculate the necessary padding/unpadding indices before the
+            # model forward.
+
+            # attn_metadata is passed in directly when using torchair;
+            # if it is not passed, we try to get it from the forward context.
+            if attn_metadata is None:
+                attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is None:
+                # When attn_metadata is still None, this is a profile run and
+                # num_tokens is the same on all dp ranks.
+                max_tokens_across_dp = hidden_states.shape[0]
+            else:
+                max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+
+            tp_size = get_tp_group().world_size
+            # reduce_scatter splits the input tensor into equal chunks and
+            # scatters them across all ranks, so we need to pad the input
+            # beforehand if its shape is not divisible by the group size.
+            # For H2P, we pad so that the shape is divisible by tp_size.
+            h2p_padded_len = (
+                tp_size - (max_tokens_across_dp % tp_size)
+            ) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
+            h2p_unpad_idx = torch.arange(hidden_states.shape[0],
+                                         device=hidden_states.device,
+                                         dtype=torch.int32)
+            h2p_pad_idx = torch.cat([
+                h2p_unpad_idx,
+                torch.zeros(h2p_padded_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device)
+            ])
+        else:
+            h2p_unpad_idx = None
+            h2p_pad_idx = None
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i - self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, h2p_unpad_idx, h2p_pad_idx,
+                i == self.start_layer)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
         hidden_states, _ = self.norm(hidden_states, residual)
+        if use_h2p():
+            if get_tp_group().world_size > 1:
+                hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
+                hidden_states = hidden_states.index_select(dim=0,
+                                                           index=h2p_unpad_idx)
         return hidden_states
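
---

For reference, the padding arithmetic used in `PanguProMoEModel.forward` can be
exercised standalone. The sketch below is illustrative only (the helper
`h2p_pad_indices` is not part of the patch); it shows how the pad/unpad index
tensors round the token count up to a multiple of `tp_size` so that
`reduce_scatter` can split the tensor evenly:

```python
import torch


def h2p_pad_indices(num_tokens: int, max_tokens_across_dp: int,
                    tp_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Mirror of the index construction in PanguProMoEModel.forward.

    The padded length is max_tokens_across_dp rounded up to a multiple of
    tp_size, so it is identical on all dp ranks and divisible by the
    reduce_scatter group size.
    """
    padded_len = -(-max_tokens_across_dp // tp_size) * tp_size
    unpad_idx = torch.arange(num_tokens, dtype=torch.int32)
    # Padding entries point at token 0; they are dropped again on unpad.
    pad_idx = torch.cat([
        unpad_idx,
        torch.zeros(padded_len - num_tokens, dtype=torch.int32)
    ])
    return unpad_idx, pad_idx


# Example: this dp rank has 2 tokens, the largest dp rank has 7, tp_size is 4.
unpad_idx, pad_idx = h2p_pad_indices(2, 7, tp_size=4)
hidden = torch.randn(2, 16)
padded = hidden.index_select(0, pad_idx)       # (8, 16): divisible by 4
assert padded.shape[0] % 4 == 0
restored = padded.index_select(0, unpad_idx)   # back to (2, 16)
assert torch.equal(restored, hidden)
```

Padding rows simply reuse token 0 and are dropped again by `index_select` with
the unpad indices, so the even-split requirement of `reduce_scatter` is met
without touching the real tokens.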