[PERF] Support H2P communication optimization for PanguProMoE (#1463)

### What this PR does / why we need it?
This PR adds H2P communication optimization when running PanguProMoE with
dp_size > 1. H2P replaces `all_reduce` with `reduce_scatter` and `all_gather`
to improve performance:

Original layer:
input_layernorm --> attn --> tp all_reduce --> post_attention_layernorm
--> dp all_gather --> moe/mlp --> dp reduce_scatter --> tp all_reduce

With H2P:
input_layernorm --> tp all_gather --> attn --> tp reduce_scatter -->
post_attention_layernorm --> all_rank all_gather --> moe/mlp -->
all_rank reduce_scatter
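
H2P relies on the identity that an `all_reduce` can be decomposed into a
`reduce_scatter` followed by an `all_gather`. Below is a minimal,
single-process sketch of that identity (ranks are simulated with plain
tensors instead of a real process group); it is not the actual collective
code used in the patch:

```python
import torch

world_size, num_tokens, hidden = 4, 8, 16
# Per-rank partial results, e.g. each rank's slice of a row-parallel matmul.
partials = [torch.randn(num_tokens, hidden) for _ in range(world_size)]

# all_reduce: every rank ends up with the full sum over ranks.
all_reduce_out = torch.stack(partials).sum(dim=0)

# reduce_scatter: each rank keeps only the summed chunk of tokens it owns.
chunks = [p.tensor_split(world_size, dim=0) for p in partials]
scattered = [sum(c[r] for c in chunks) for r in range(world_size)]

# all_gather: concatenating the per-rank chunks restores the all_reduce result.
all_gather_out = torch.cat(scattered, dim=0)

assert torch.allclose(all_reduce_out, all_gather_out, atol=1e-6)
```

The sketch only shows the numerical equivalence; in this PR the two halves
are placed on opposite sides of the layernorm and moe/mlp, so each rank
carries only its shard of tokens in between.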

Besides, because `reduce_scatter` requires the number of tokens to be
divisible by the group size, we need to pad the sequences based on
`max_tokens_across_dp`.
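
As a rough sketch of the padding arithmetic (the helper name
`h2p_pad_indices` is only illustrative; in the patch the same computation is
inlined in `PanguProMoEModel.forward`), padding rows reuse token index 0 and
are dropped again through the unpad index after the final gather:

```python
import torch


def h2p_pad_indices(num_tokens: int, max_tokens_across_dp: int, tp_size: int,
                    device: str = "cpu"):
    # Round max_tokens_across_dp up to a multiple of tp_size, then pad this
    # rank's tokens up to that length so reduce_scatter splits evenly.
    padded_len = ((tp_size - max_tokens_across_dp % tp_size) % tp_size +
                  max_tokens_across_dp - num_tokens)
    unpad_idx = torch.arange(num_tokens, dtype=torch.int32, device=device)
    pad_idx = torch.cat([
        unpad_idx,
        torch.zeros(padded_len, dtype=torch.int32, device=device)
    ])
    return unpad_idx, pad_idx


# e.g. this dp rank has 5 tokens, the largest dp rank has 7, tp_size is 4:
unpad_idx, pad_idx = h2p_pad_indices(num_tokens=5, max_tokens_across_dp=7,
                                     tp_size=4)
print(pad_idx.shape[0])  # 8 padded tokens, divisible by tp_size
```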

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This PR has been tested with both offline and online inference using
PanguProMoE-72B.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Author: Angazenn
Date: 2025-06-28 16:10:27 +08:00
Committed by: GitHub
Parent: 5c53cbaf2a
Commit: 8fa188111d


@@ -18,20 +18,26 @@
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
+from torch.nn import Parameter
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
+                                             get_world_group)
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -47,6 +53,7 @@ from vllm.model_executor.models.utils import (
     extract_layer_index, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -56,6 +63,225 @@ logger = init_logger(__name__)
 _ROUTER_SCALE = None
 
 
+def use_h2p():
+    # only use H2P when dp_size > 1.
+    if get_dp_group().world_size > 1:
+        return True
+    return False
+
+
+# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
+# It is used to customize parallelism of certain linear (e.g., shared experts with all-rank tp).
+class CustomMergedColumnParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        # Divide the weight matrix along the last dimension.
+        output_size = sum(output_sizes)
+        self.output_sizes = output_sizes
+        self.tp_size = get_world_group().world_size
+        self.input_size_per_partition = input_size
+        self.output_size_per_partition = divide(output_size, self.tp_size)
+        self.output_partition_sizes = [self.output_size_per_partition]
+        # If QKV or MergedColumn, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = [
+                divide(output_size, self.tp_size)
+                for output_size in self.output_sizes
+            ]
+
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
+
+        self.gather_output = gather_output
+
+        if output_sizes is None:
+            output_sizes = [output_size]
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(self.weight_loader))
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
+                      loaded_shard_id: int):
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        tp_rank = get_world_group().rank_in_group
+        tp_size = get_world_group().world_size
+        if output_dim is not None:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+            is_sharded_weight = getattr(param, "is_sharded_weight", False)
+
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            start_idx = tp_rank * shard_size
+
+            if not is_sharded_weight:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_, bias)
+        output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
+# It is used to customize parallelism of certain linear (e.g., shared experts with all-rank tp)
+# and detach communication to enable customized communication algorithms (e.g., H2P).
+class CustomRowParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        group=None,
+    ):
+        # Divide the weight matrix along the first dimension.
+        self.group = group if group is not None else get_world_group()
+        self.tp_rank = self.group.rank_in_group
+        self.tp_size = self.group.world_size
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(self.weight_loader))
+
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = self.group.rank_in_group
+        input_dim = getattr(param, "input_dim", None)
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+
+        param_data = param.data
+        if input_dim is not None and not is_sharded_weight:
+            shard_size = param_data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
+                                                 shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        input_parallel = input_
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output = self.quant_method.apply(self, input_parallel, bias=bias_)
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class PanguProMoEMLP(nn.Module):
 
     def __init__(
@@ -68,21 +294,39 @@ class PanguProMoEMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            reduce_results=reduce_results,
-            prefix=f"{prefix}.down_proj",
-        )
+        if not use_h2p():
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size,
+                [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                reduce_results=reduce_results,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.gate_up_proj = CustomMergedColumnParallelLinear(
+                hidden_size,
+                [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
+            self.down_proj = CustomRowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                reduce_results=reduce_results,
+                prefix=f"{prefix}.down_proj",
+            )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -207,11 +451,30 @@ class PanguProMoESparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         global _ROUTER_SCALE
         _ROUTER_SCALE = self.router_scale
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        if not use_h2p():
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
+        else:
+            # TODO: when using h2p, we have to skip communication in vLLM
+            # native FusedMoE. here we need to design a better FusedMoE
+            # (maybe using AscendFusedMoE) to enable these different
+            # communication schema.
+            final_hidden_states = self.experts.quant_method(
+                layer=self.experts,
+                x=hidden_states,
+                router_logits=router_logits,
+                top_k=self.experts.top_k,
+                renormalize=False,
+                use_grouped_topk=False,
+                global_num_experts=self.experts.global_num_experts,
+                expert_map=self.experts.expert_map,
+                custom_routing_function=self.experts.custom_routing_function,
+                apply_router_weight_on_input=self.experts.
+                apply_router_weight_on_input)
 
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if not use_h2p():
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -265,13 +528,22 @@ class PanguProMoEAttention(nn.Module):
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( if use_h2p():
self.total_num_heads * self.head_dim, self.o_proj = CustomRowParallelLinear(self.total_num_heads *
hidden_size, self.head_dim,
bias=True, hidden_size,
quant_config=quant_config, bias=True,
prefix=f"{prefix}.o_proj", quant_config=quant_config,
) prefix=f"{prefix}.o_proj",
group=get_tp_group())
else:
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
@@ -337,8 +609,7 @@ class PanguProMoEDecoderLayer(nn.Module):
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        if (layer_idx
-                not in mlp_only_layers) and (config.num_experts > 0):
+        if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
             self.mlp = PanguProMoESparseMoeBlock(
                 config=config,
                 quant_config=quant_config,
@@ -364,7 +635,14 @@ class PanguProMoEDecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        h2p_unpad_idx: Optional[torch.Tensor] = None,
+        h2p_pad_idx: Optional[torch.Tensor] = None,
+        is_start_layer: Optional[bool] = False,
     ) -> torch.Tensor:
+        need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
+            and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
+        tp_size = get_tp_group().world_size
+
         # Self Attention
         if residual is None:
             residual = hidden_states
@@ -372,16 +650,64 @@ class PanguProMoEDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+
+        if use_h2p():
+            if is_start_layer:
+                if need_h2p_pad:
+                    residual = residual.index_select(dim=0, index=h2p_pad_idx)
+                residual = torch.tensor_split(
+                    residual, tp_size)[get_tp_group().rank_in_group]
+            else:
+                if tp_size > 1:
+                    hidden_states = get_tp_group().all_gather(hidden_states, 0)
+                if need_h2p_pad:
+                    hidden_states = hidden_states.index_select(
+                        dim=0, index=h2p_unpad_idx)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
         )
+
+        if use_h2p():
+            if need_h2p_pad:
+                hidden_states = hidden_states.index_select(dim=0,
+                                                           index=h2p_pad_idx)
+            if tp_size > 1:
+                hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_tp_group().device_group)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if use_h2p():
+            all_rank_group = get_world_group().device_group
+            output_size = (hidden_states.shape[0] *
+                           get_world_group().world_size,
+                           hidden_states.shape[1])
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=hidden_states.dtype,
+                                        device=hidden_states.device)
+            # All-gather.
+            dist.all_gather_into_tensor(output_tensor,
+                                        hidden_states,
+                                        group=all_rank_group)
+            hidden_states = output_tensor
+
+        hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
+
+        if use_h2p():
+            hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                hidden_states,
+                "sum",
+                scatter_dim=0,
+                group=get_world_group().device_group)
+
         return hidden_states, residual
@@ -440,19 +766,61 @@ class PanguProMoEModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
+
+        if use_h2p():
+            # calculate necessary padding/unpadding idx before model forward.
+            # the attn_metadata is passed in directly when using torchair;
+            # if it is not passed, we try to get it from the forward_context.
+            if attn_metadata is None:
+                attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is None:
+                # when attn_metadata is None, we are in profile_run and
+                # num_tokens is the same on all dp ranks.
+                max_tokens_across_dp = hidden_states.shape[0]
+            else:
+                max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+            tp_size = get_tp_group().world_size
+            # reduce_scatter splits the input tensor into equal chunks and
+            # scatters them across all ranks, so the input must be padded
+            # beforehand if its length is not divisible by the group size.
+            # for h2p, we pad it so that it is divisible by tp_size.
+            h2p_padded_len = (
+                tp_size - (max_tokens_across_dp % tp_size)
+            ) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
+            h2p_unpad_idx = torch.arange(hidden_states.shape[0],
+                                         device=hidden_states.device,
+                                         dtype=torch.int32)
+            h2p_pad_idx = torch.cat([
+                h2p_unpad_idx,
+                torch.zeros(h2p_padded_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device)
+            ])
+        else:
+            h2p_unpad_idx = None
+            h2p_pad_idx = None
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, h2p_unpad_idx, h2p_pad_idx,
+                i == self.start_layer)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
 
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if use_h2p():
+            if get_tp_group().world_size > 1:
+                hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
+                hidden_states = hidden_states.index_select(dim=0,
+                                                           index=h2p_unpad_idx)
+
         return hidden_states