[PERF]support H2P communication optimization for PanguProMoe (#1463)

### What this PR does / why we need it?
In this PR, we support H2P communication optimization when running
PanguProMoE with dp_size > 1. H2P uses `reduce_scatter` and `all_gather`
to replace `all_reduce` to improve performance:

original layer:
input_layernorm --> attn --> tp all_reduce --> post_attention_layernorm
--> dp all_gather --> moe/mlp --> dp reduce_scatter --> tp all_reduce
now:
input_layernorm --> tp all_gather --> attn --> tp reduce_scatter -->
post_attention_layernorm --> all_rank all_gather --> moe/mlp -->
all_rank reduce_scatter

Besides, because `reduce_scatter` requires num_tokens to be divisible
by the group size, we need to pad the sequences based on
`max_tokens_across_dp`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This PR has been tested with both offline and online inference using
PanguProMoE-72B.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-06-28 16:10:27 +08:00
committed by GitHub
parent 5c53cbaf2a
commit 8fa188111d

View File

@@ -18,20 +18,26 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
get_world_group)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@@ -47,6 +53,7 @@ from vllm.model_executor.models.utils import (
extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -56,6 +63,225 @@ logger = init_logger(__name__)
_ROUTER_SCALE = None
def use_h2p():
    """Return True when the H2P communication optimization should be used.

    H2P replaces tp all_reduce with all_gather/reduce_scatter pairs and is
    only enabled when data parallelism is in use (dp_size > 1).
    """
    # Only use H2P when dp_size > 1.
    return get_dp_group().world_size > 1
# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp).
class CustomMergedColumnParallelLinear(LinearBase):
    """Merged column-parallel linear sharded over the *world* group.

    Unlike vLLM's MergedColumnParallelLinear (sharded over the TP group),
    this variant partitions the output dimension across all ranks
    (get_world_group()), which is what H2P's all-rank tensor parallelism
    for shared experts requires. forward() performs no communication;
    the caller is responsible for any gather.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        # Shard over the *world* group (all ranks), not the TP group.
        output_size = sum(output_sizes)
        self.output_sizes = output_sizes
        self.tp_size = get_world_group().world_size
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # NOTE(review): self.output_sizes was assigned just above, so this
        # hasattr check is always True here; kept as-is from the adapted
        # vLLM code, where subclasses may or may not define output_sizes.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        # NOTE(review): stored but not acted on in forward(); communication
        # is intentionally detached for H2P.
        self.gather_output = gather_output

        # NOTE(review): dead branch — sum(output_sizes) above would already
        # have raised if output_sizes were None; kept from the adapted code.
        if output_sizes is None:
            output_sizes = [output_size]

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      loaded_shard_id: int):
        """Load one shard (e.g. gate or up projection) of the merged weight.

        Narrows the checkpoint weight for shard `loaded_shard_id` to this
        rank's slice along `output_dim` and copies it into the matching
        region of the local parameter.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        assert loaded_shard_id < len(self.output_sizes)
        # Rank/size taken from the world group to match the partitioning
        # chosen in __init__.
        tp_rank = get_world_group().rank_in_group
        tp_size = get_world_group().world_size
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

            is_sharded_weight = getattr(param, "is_sharded_weight", False)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size

            # Already-sharded checkpoints are copied as-is; otherwise take
            # this rank's slice of the full weight.
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the sharded linear; no output gather is performed here.

        The caller applies H2P's custom communication scheme afterwards.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp)
# and detach communication to enable customized communication algorithms(e.g., H2P).
class CustomRowParallelLinear(LinearBase):
    """Row-parallel linear with a configurable process group and no
    built-in output reduction.

    The input dimension is partitioned over `group` (defaults to the world
    group, i.e. all-rank tensor parallelism). forward() performs only the
    local matmul — the caller is responsible for the reduce_scatter /
    all_reduce, which is what allows H2P to substitute its own
    communication algorithm.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        group=None,
    ):
        # Divide the weight matrix along the first dimension over `group`
        # (world group by default, for all-rank shared-expert TP).
        self.group = group if group is not None else get_world_group()
        self.tp_rank = self.group.rank_in_group
        self.tp_size = self.group.world_size
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        # NOTE(review): stored but not consulted in forward(); reduction is
        # intentionally left to the caller (H2P).
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))

        # Without a reduction, adding a fused bias on every rank would sum
        # it tp_size times — reject that configuration up front.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of the checkpoint weight into `param`.

        Narrows along `input_dim` unless the checkpoint is already sharded.
        """
        tp_rank = self.group.rank_in_group
        input_dim = getattr(param, "input_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the local matmul only; the caller performs the reduction."""
        input_parallel = input_
        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output = self.quant_method.apply(self, input_parallel, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
class PanguProMoEMLP(nn.Module):
def __init__(
@@ -68,21 +294,39 @@ class PanguProMoEMLP(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if not use_h2p():
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
else:
self.gate_up_proj = CustomMergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = CustomRowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -207,11 +451,30 @@ class PanguProMoESparseMoeBlock(nn.Module):
router_logits, _ = self.gate(hidden_states)
global _ROUTER_SCALE
_ROUTER_SCALE = self.router_scale
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not use_h2p():
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
# TODO: when using h2p, we have to skip communication in vLLM
# native FusedMoE. here we need to design a better FusedMoE
# (maybe using AscendFusedMoE) to enable these different
# communication schema.
final_hidden_states = self.experts.quant_method(
layer=self.experts,
x=hidden_states,
router_logits=router_logits,
top_k=self.experts.top_k,
renormalize=False,
use_grouped_topk=False,
global_num_experts=self.experts.global_num_experts,
expert_map=self.experts.expert_map,
custom_routing_function=self.experts.custom_routing_function,
apply_router_weight_on_input=self.experts.
apply_router_weight_on_input)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if not use_h2p():
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
@@ -265,13 +528,22 @@ class PanguProMoEAttention(nn.Module):
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
if use_h2p():
self.o_proj = CustomRowParallelLinear(self.total_num_heads *
self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
group=get_tp_group())
else:
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -337,8 +609,7 @@ class PanguProMoEDecoderLayer(nn.Module):
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx
not in mlp_only_layers) and (config.num_experts > 0): ### ???
if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
self.mlp = PanguProMoESparseMoeBlock(
config=config,
quant_config=quant_config,
@@ -364,7 +635,14 @@ class PanguProMoEDecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
h2p_unpad_idx: Optional[torch.Tensor] = None,
h2p_pad_idx: Optional[torch.Tensor] = None,
is_start_layer: Optional[bool] = False,
) -> torch.Tensor:
need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
tp_size = get_tp_group().world_size
# Self Attention
if residual is None:
residual = hidden_states
@@ -372,16 +650,64 @@ class PanguProMoEDecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if use_h2p():
if is_start_layer:
if need_h2p_pad:
residual = residual.index_select(dim=0, index=h2p_pad_idx)
residual = torch.tensor_split(
residual, tp_size)[get_tp_group().rank_in_group]
else:
if tp_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
if need_h2p_pad:
hidden_states = hidden_states.index_select(
dim=0, index=h2p_unpad_idx)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if use_h2p():
if need_h2p_pad:
hidden_states = hidden_states.index_select(dim=0,
index=h2p_pad_idx)
if tp_size > 1:
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_tp_group().device_group)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if use_h2p():
all_rank_group = get_world_group().device_group
output_size = (hidden_states.shape[0] *
get_world_group().world_size,
hidden_states.shape[1])
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=hidden_states.dtype,
device=hidden_states.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
hidden_states,
group=all_rank_group)
hidden_states = output_tensor
hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
if use_h2p():
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
"sum",
scatter_dim=0,
group=get_world_group().device_group)
return hidden_states, residual
@@ -440,19 +766,61 @@ class PanguProMoEModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if use_h2p():
# calculate necessary padding/unpadding idx before model forward.
# the attn_metadata will be passed directly when use torchair.
# if attn_meatadata is not passed, we try to get it from forward_context.
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
# when attn_meatadata is None, it is in profile_run. num_tokens on all dp ranks
# are same.
max_tokens_across_dp = hidden_states.shape[0]
else:
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
tp_size = get_tp_group().world_size
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
# we need pad it before if the shape can't be divided by group size.
# for h2p, we need pad it so that it can be divided by tp_size.
h2p_padded_len = (
tp_size - (max_tokens_across_dp % tp_size)
) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
h2p_unpad_idx = torch.arange(hidden_states.shape[0],
device=hidden_states.device,
dtype=torch.int32)
h2p_pad_idx = torch.cat([
h2p_unpad_idx,
torch.zeros(h2p_padded_len,
dtype=torch.int32,
device=hidden_states.device)
])
else:
h2p_unpad_idx = None
h2p_pad_idx = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
attn_metadata, h2p_unpad_idx, h2p_pad_idx,
i == self.start_layer)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if use_h2p():
if get_tp_group().world_size > 1:
hidden_states = get_tp_group().all_gather(hidden_states, 0)
if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
hidden_states = hidden_states.index_select(dim=0,
index=h2p_unpad_idx)
return hidden_states