[PERF] Support H2P communication optimization for PanguProMoE (#1463)
### What this PR does / why we need it? In this PR, we support the H2P communication optimization when running PanguProMoE with dp_size > 1. H2P uses `reduce_scatter` and `all_gather` in place of `all_reduce` to improve performance. Original layer: input_layernorm --> attn --> tp all_reduce --> post_attention_layernorm --> dp all_gather --> moe/mlp --> dp reduce_scatter --> tp all_reduce. Now: input_layernorm --> tp all_gather --> attn --> tp reduce_scatter --> post_attention_layernorm --> all_rank all_gather --> moe/mlp --> all_rank reduce_scatter. In addition, because `reduce_scatter` requires a num_tokens that is divisible by the group size, we need to pad the sequences based on `max_tokens_across_dp`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR has been tested with both offline and online inference using PanguProMoE-72B. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -18,20 +18,26 @@
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import Parameter
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group,
|
from vllm.distributed import (divide, get_pp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.distributed.parallel_state import (get_dp_group, get_tp_group,
|
||||||
|
get_world_group)
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@@ -47,6 +53,7 @@ from vllm.model_executor.models.utils import (
|
|||||||
extract_layer_index, is_pp_missing_parameter,
|
extract_layer_index, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||||
@@ -56,6 +63,225 @@ logger = init_logger(__name__)
|
|||||||
_ROUTER_SCALE = None
|
_ROUTER_SCALE = None
|
||||||
|
|
||||||
|
|
||||||
|
def use_h2p():
    """Report whether the H2P communication path should be used.

    H2P is enabled only when the data-parallel group has more than one
    rank (dp_size > 1).
    """
    return get_dp_group().world_size > 1
|
||||||
|
|
||||||
|
|
||||||
|
# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear.
|
||||||
|
# It is used to customize the parallelism of certain linear layers (e.g., shared experts with all-rank TP).
|
||||||
|
class CustomMergedColumnParallelLinear(LinearBase):
    """Merged column-parallel linear layer sharded over the world group.

    Adapted from vLLM's ``MergedColumnParallelLinear``. Each merged
    sub-matrix listed in ``output_sizes`` is divided along its output
    dimension by ``get_world_group().world_size``; the input dimension is
    kept whole on every rank. ``forward`` only runs the local matmul via
    the quantization method and issues no collective itself.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Create the sharded weight (and optional bias) for this rank.

        Args:
            input_size: Size of the input feature dimension (not sharded).
            output_sizes: Output sizes of the merged sub-matrices; each one
                must be divisible by the world-group size.
            bias: Whether to allocate a bias of the per-rank output size.
            gather_output: Stored on the instance; ``forward`` in this
                class does not consult it.
            skip_bias_add: If true, the bias is not applied in ``forward``
                and is returned alongside the output instead.
            params_dtype: Dtype forwarded to ``LinearBase`` and used for
                the bias allocation.
            quant_config: Quantization config forwarded to ``LinearBase``.
            prefix: Parameter-name prefix forwarded to ``LinearBase``.
            return_bias: If false, ``forward`` returns only the output.
        """
        # These attributes are set before LinearBase.__init__ runs, matching
        # the original statement order.
        output_size = sum(output_sizes)
        self.output_sizes = output_sizes
        self.tp_size = get_world_group().world_size
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        # One entry per merged sub-matrix: its share of the output dimension.
        self.output_partition_sizes = [
            divide(size, self.tp_size) for size in self.output_sizes
        ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      loaded_shard_id: int):
        """Copy one merged shard of ``loaded_weight`` into ``param``.

        Args:
            param: Destination parameter. Its optional ``output_dim``
                attribute names the sharded dimension.
            loaded_weight: Tensor read from the checkpoint for the
                sub-matrix ``loaded_shard_id``.
            loaded_shard_id: Index into ``self.output_sizes`` selecting the
                merged sub-matrix being loaded.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        assert loaded_shard_id < len(self.output_sizes)

        world_group = get_world_group()
        tp_rank = world_group.rank_in_group
        tp_size = world_group.world_size
        if output_dim is not None:
            # Destination slice inside the merged per-rank parameter.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Source slice: this rank's portion of the full checkpoint
            # tensor, unless the checkpoint tensor is already sharded.
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the local (per-rank) linear transform.

        Returns:
            The local output alone when ``return_bias`` is false; otherwise
            a tuple ``(output, output_bias)`` where ``output_bias`` is the
            bias when ``skip_bias_add`` is set and ``None`` otherwise.
        """
        # The bias is fused into the matmul unless the caller adds it later.
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
|
# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear.
|
||||||
|
# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp)
|
||||||
|
# and detach communication to enable customized communication algorithms(e.g., H2P).
|
||||||
|
class CustomRowParallelLinear(LinearBase):
    """Row-parallel linear layer with a configurable process group.

    Adapted from vLLM's ``RowParallelLinear``. The input dimension is
    divided by the size of ``group`` (the world group when ``group`` is
    ``None``). ``forward`` only runs the local matmul: it neither splits
    the input nor reduces the output, so the caller is responsible for any
    collective (e.g. the reduce-scatter used by H2P).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        group=None,
    ):
        """Create the sharded weight (and optional bias) for this rank.

        Args:
            input_size: Full input feature size; must be divisible by the
                group size.
            output_size: Output feature size (not sharded).
            bias: Whether to allocate a bias of size ``output_size``.
            input_is_parallel: Stored on the instance; ``forward`` in this
                class does not consult it.
            skip_bias_add: If true, the bias is not applied in ``forward``
                and is returned alongside the output instead.
            params_dtype: Dtype forwarded to ``LinearBase`` and used for
                the bias allocation.
            reduce_results: Stored on the instance and used only for the
                bias-consistency check below; ``forward`` does not reduce.
            quant_config: Quantization config forwarded to ``LinearBase``.
            prefix: Parameter-name prefix forwarded to ``LinearBase``.
            return_bias: If false, ``forward`` returns only the output.
            group: Process group to shard over; defaults to the world
                group.

        Raises:
            ValueError: If ``reduce_results`` is false while a bias would
                be added inside ``forward``.
        """
        # Divide the weight matrix along the first dimension.
        self.group = group if group is not None else get_world_group()
        self.tp_rank = self.group.rank_in_group
        self.tp_size = self.group.world_size
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` carries an ``input_dim`` attribute and is not marked
        ``is_sharded_weight``, the checkpoint tensor is narrowed along that
        dimension to the slice owned by this rank before copying.
        """
        input_dim = getattr(param, "input_dim", None)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.group.rank_in_group * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the local (per-rank) linear transform without reduction.

        Returns:
            The local partial output alone when ``return_bias`` is false;
            otherwise a tuple ``(output, output_bias)`` where
            ``output_bias`` is the bias when ``skip_bias_add`` is set and
            ``None`` otherwise.
        """
        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output = self.quant_method.apply(self, input_, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias
|
||||||
|
|
||||||
|
|
||||||
class PanguProMoEMLP(nn.Module):
|
class PanguProMoEMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -68,21 +294,39 @@ class PanguProMoEMLP(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
if not use_h2p():
|
||||||
hidden_size,
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
[intermediate_size] * 2,
|
hidden_size,
|
||||||
bias=False,
|
[intermediate_size] * 2,
|
||||||
quant_config=quant_config,
|
bias=False,
|
||||||
prefix=f"{prefix}.gate_up_proj",
|
quant_config=quant_config,
|
||||||
)
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
self.down_proj = RowParallelLinear(
|
)
|
||||||
intermediate_size,
|
self.down_proj = RowParallelLinear(
|
||||||
hidden_size,
|
intermediate_size,
|
||||||
bias=False,
|
hidden_size,
|
||||||
quant_config=quant_config,
|
bias=False,
|
||||||
reduce_results=reduce_results,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.down_proj",
|
reduce_results=reduce_results,
|
||||||
)
|
prefix=f"{prefix}.down_proj",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.gate_up_proj = CustomMergedColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
[intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
|
)
|
||||||
|
self.down_proj = CustomRowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
reduce_results=reduce_results,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
)
|
||||||
|
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
@@ -207,11 +451,30 @@ class PanguProMoESparseMoeBlock(nn.Module):
|
|||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
global _ROUTER_SCALE
|
global _ROUTER_SCALE
|
||||||
_ROUTER_SCALE = self.router_scale
|
_ROUTER_SCALE = self.router_scale
|
||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
if not use_h2p():
|
||||||
router_logits=router_logits)
|
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits)
|
||||||
|
else:
|
||||||
|
# TODO: when using h2p, we have to skip communication in vLLM
|
||||||
|
# native FusedMoE. here we need to design a better FusedMoE
|
||||||
|
# (maybe using AscendFusedMoE) to enable these different
|
||||||
|
# communication schema.
|
||||||
|
final_hidden_states = self.experts.quant_method(
|
||||||
|
layer=self.experts,
|
||||||
|
x=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=self.experts.top_k,
|
||||||
|
renormalize=False,
|
||||||
|
use_grouped_topk=False,
|
||||||
|
global_num_experts=self.experts.global_num_experts,
|
||||||
|
expert_map=self.experts.expert_map,
|
||||||
|
custom_routing_function=self.experts.custom_routing_function,
|
||||||
|
apply_router_weight_on_input=self.experts.
|
||||||
|
apply_router_weight_on_input)
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1:
|
if not use_h2p():
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
final_hidden_states)
|
final_hidden_states)
|
||||||
|
|
||||||
@@ -265,13 +528,22 @@ class PanguProMoEAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.qkv_proj",
|
prefix=f"{prefix}.qkv_proj",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = RowParallelLinear(
|
if use_h2p():
|
||||||
self.total_num_heads * self.head_dim,
|
self.o_proj = CustomRowParallelLinear(self.total_num_heads *
|
||||||
hidden_size,
|
self.head_dim,
|
||||||
bias=True,
|
hidden_size,
|
||||||
quant_config=quant_config,
|
bias=True,
|
||||||
prefix=f"{prefix}.o_proj",
|
quant_config=quant_config,
|
||||||
)
|
prefix=f"{prefix}.o_proj",
|
||||||
|
group=get_tp_group())
|
||||||
|
else:
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -337,8 +609,7 @@ class PanguProMoEDecoderLayer(nn.Module):
|
|||||||
layer_idx = extract_layer_index(prefix)
|
layer_idx = extract_layer_index(prefix)
|
||||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||||
config.mlp_only_layers)
|
config.mlp_only_layers)
|
||||||
if (layer_idx
|
if (layer_idx not in mlp_only_layers) and (config.num_experts > 0):
|
||||||
not in mlp_only_layers) and (config.num_experts > 0): ### ???
|
|
||||||
self.mlp = PanguProMoESparseMoeBlock(
|
self.mlp = PanguProMoESparseMoeBlock(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -364,7 +635,14 @@ class PanguProMoEDecoderLayer(nn.Module):
|
|||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
|
h2p_unpad_idx: Optional[torch.Tensor] = None,
|
||||||
|
h2p_pad_idx: Optional[torch.Tensor] = None,
|
||||||
|
is_start_layer: Optional[bool] = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \
|
||||||
|
and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]
|
||||||
|
tp_size = get_tp_group().world_size
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -372,16 +650,64 @@ class PanguProMoEDecoderLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(
|
hidden_states, residual = self.input_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
if use_h2p():
|
||||||
|
if is_start_layer:
|
||||||
|
if need_h2p_pad:
|
||||||
|
residual = residual.index_select(dim=0, index=h2p_pad_idx)
|
||||||
|
residual = torch.tensor_split(
|
||||||
|
residual, tp_size)[get_tp_group().rank_in_group]
|
||||||
|
else:
|
||||||
|
if tp_size > 1:
|
||||||
|
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||||
|
if need_h2p_pad:
|
||||||
|
hidden_states = hidden_states.index_select(
|
||||||
|
dim=0, index=h2p_unpad_idx)
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if use_h2p():
|
||||||
|
if need_h2p_pad:
|
||||||
|
hidden_states = hidden_states.index_select(dim=0,
|
||||||
|
index=h2p_pad_idx)
|
||||||
|
if tp_size > 1:
|
||||||
|
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||||
|
hidden_states,
|
||||||
|
"sum",
|
||||||
|
scatter_dim=0,
|
||||||
|
group=get_tp_group().device_group)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
|
if use_h2p():
|
||||||
|
all_rank_group = get_world_group().device_group
|
||||||
|
output_size = (hidden_states.shape[0] *
|
||||||
|
get_world_group().world_size,
|
||||||
|
hidden_states.shape[1])
|
||||||
|
# Allocate output tensor.
|
||||||
|
output_tensor = torch.empty(output_size,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device)
|
||||||
|
# All-gather.
|
||||||
|
dist.all_gather_into_tensor(output_tensor,
|
||||||
|
hidden_states,
|
||||||
|
group=all_rank_group)
|
||||||
|
hidden_states = output_tensor
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata)
|
||||||
|
|
||||||
|
if use_h2p():
|
||||||
|
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||||
|
hidden_states,
|
||||||
|
"sum",
|
||||||
|
scatter_dim=0,
|
||||||
|
group=get_world_group().device_group)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -440,19 +766,61 @@ class PanguProMoEModel(nn.Module):
|
|||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
if use_h2p():
|
||||||
|
# calculate necessary padding/unpadding idx before model forward.
|
||||||
|
|
||||||
|
# the attn_metadata will be passed directly when use torchair.
|
||||||
|
# If attn_metadata is not passed, we try to get it from forward_context.
|
||||||
|
if attn_metadata is None:
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
if attn_metadata is None:
|
||||||
|
# When attn_metadata is None, we are in profile_run; num_tokens on all dp ranks
|
||||||
|
# are same.
|
||||||
|
max_tokens_across_dp = hidden_states.shape[0]
|
||||||
|
else:
|
||||||
|
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||||
|
|
||||||
|
tp_size = get_tp_group().world_size
|
||||||
|
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.
|
||||||
|
# We need to pad it beforehand if the shape is not divisible by the group size.
|
||||||
|
# For H2P, we need to pad it so that it is divisible by tp_size.
|
||||||
|
h2p_padded_len = (
|
||||||
|
tp_size - (max_tokens_across_dp % tp_size)
|
||||||
|
) % tp_size + max_tokens_across_dp - hidden_states.shape[0]
|
||||||
|
h2p_unpad_idx = torch.arange(hidden_states.shape[0],
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
h2p_pad_idx = torch.cat([
|
||||||
|
h2p_unpad_idx,
|
||||||
|
torch.zeros(h2p_padded_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=hidden_states.device)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
h2p_unpad_idx = None
|
||||||
|
h2p_pad_idx = None
|
||||||
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, residual,
|
positions, hidden_states, residual,
|
||||||
kv_caches[i -
|
kv_caches[i -
|
||||||
self.start_layer] if kv_caches is not None else None,
|
self.start_layer] if kv_caches is not None else None,
|
||||||
attn_metadata)
|
attn_metadata, h2p_unpad_idx, h2p_pad_idx,
|
||||||
|
i == self.start_layer)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
"residual": residual
|
"residual": residual
|
||||||
})
|
})
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
if use_h2p():
|
||||||
|
if get_tp_group().world_size > 1:
|
||||||
|
hidden_states = get_tp_group().all_gather(hidden_states, 0)
|
||||||
|
if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]:
|
||||||
|
hidden_states = hidden_states.index_select(dim=0,
|
||||||
|
index=h2p_unpad_idx)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user