[PERF] support MERRouter (#1421)
### What this PR does / why we need it?
This PR introduces an expert rearrangement algorithm for the PanguProMoE model. Unlike the original grouped topk, it keeps only the experts that are allocated the most tokens, so fewer experts need to be loaded when computing the grouped matrix multiplication (GMM). We have tested this algorithm with PanguProMoE-72B on the 300I Duo and 800I A2 platforms. On 300I Duo, we find that setting `num_voted_experts` to 5 achieves both good performance and good accuracy. On 800I A2, we keep it at 8, which falls back to the original pangu grouped topk.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
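In outline, the routing change works as follows: every token votes for its preferred expert within each expert group, only the `num_voted_experts` most-voted experts per group remain eligible, and each token is then routed to its best eligible expert. Below is a minimal standalone sketch of that idea with toy sizes and names of my own choosing (the real implementation is `pangu_group8_topk` in the diff that follows):

```python
import torch
import torch.nn.functional as F

# Toy sizes: 2 groups of 8 experts, keep the 5 most-voted experts per group.
num_tokens, num_groups, experts_per_group, num_voted = 16, 2, 8, 5
scores = F.softmax(torch.randn(num_tokens, num_groups * experts_per_group), dim=1)
grouped = scores.view(num_tokens, num_groups, experts_per_group)

# 1) Every token "votes" for its best expert inside each group.
votes = torch.argmax(grouped, dim=2)                   # (num_tokens, num_groups)
freq = F.one_hot(votes, experts_per_group).sum(dim=0)  # votes per expert, per group

# 2) Only the num_voted most-voted experts in each group stay eligible.
keep = torch.argsort(freq, dim=1, descending=True)[:, :num_voted]
keep_mask = F.one_hot(keep, experts_per_group).sum(dim=1) > 0

# 3) Each token routes to its best expert among the kept ones, so the grouped
#    matmul only has to load this smaller, shared set of experts.
masked = torch.where(keep_mask.unsqueeze(0), grouped, torch.zeros_like(grouped))
topk_weights, topk_ids = masked.max(dim=2)             # (num_tokens, num_groups)
```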
```diff
@@ -57,6 +57,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import is_310p
 
 logger = init_logger(__name__)
 
@@ -339,41 +340,81 @@ class PanguProMoEMLP(nn.Module):
         return x
 
 
-class PanguProMoESparseMoeBlock(nn.Module):
+def topk_wrapper(num_voted_experts):
 
-    @staticmethod
     def pangu_group8_topk(
         hidden_states: torch.Tensor,
         gating_output: torch.Tensor,
         topk: int,
-        renormalize: bool,
+        renormalize: bool = False,
         num_expert_group: int = 0,
         topk_group: int = 0,
         global_num_experts: int = 0,
     ):
         scores = F.softmax(gating_output, dim=1)
-        router_scale = _ROUTER_SCALE.squeeze(  # type: ignore
-        )
+        num_tokens = scores.shape[0]
+        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
+
         ep_size = get_ep_group().world_size
         local_num_experts = global_num_experts // ep_size
         local_num_group = topk // ep_size
-        scores = scores[...,
-                        get_ep_group().rank_in_group *
-                        local_num_experts:(get_ep_group().rank_in_group + 1) *
-                        local_num_experts]
-        router_weights = router_scale[get_ep_group().rank_in_group *
-                                      local_num_experts:
-                                      (get_ep_group().rank_in_group + 1) *
-                                      local_num_experts]
-        topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
-                                                       local_num_group, -1),
-                                           dim=-1)
-        bias = torch.arange(0,
-                            local_num_experts,
-                            topk,
-                            device=scores.device,
-                            dtype=torch.int32).unsqueeze(0)
-        topk_ids = topk_ids.to(torch.int32) + bias
+        experts_per_group = global_num_experts // topk
+        local_group_start = get_ep_group().rank_in_group * local_num_experts
+        local_group_end = (get_ep_group().rank_in_group +
+                           1) * local_num_experts
+        scores = scores[..., local_group_start:local_group_end]
+        router_weights = router_scale[local_group_start:local_group_end]
+
+        if num_voted_experts == 8:
+            # use original topk
+            topk_weights, topk_ids = torch.max(scores.view(
+                scores.shape[0], local_num_group, -1),
+                                               dim=-1)
+            bias = torch.arange(0,
+                                local_num_experts,
+                                experts_per_group,
+                                device=scores.device,
+                                dtype=torch.int32).unsqueeze(0)
+            topk_ids = topk_ids.to(torch.int32) + bias
+        else:
+            group_expert_indices = torch.arange(experts_per_group,
+                                                dtype=torch.int32,
+                                                device=scores.device).view(
+                                                    1, 1, -1)
+            group_expert_offset = (torch.arange(
+                local_num_group, dtype=torch.int32, device=scores.device) *
+                                   experts_per_group).unsqueeze(0)
+            expert_index_range = torch.arange(experts_per_group,
+                                              dtype=torch.int32,
+                                              device=scores.device)
+
+            scores_grouped = scores.view(num_tokens, local_num_group,
+                                         experts_per_group)
+            best_expert_idx = torch.argmax(scores_grouped,
+                                           dim=2)  # (num_tokens, num_groups)
+            vote_mask = (best_expert_idx.unsqueeze(-1).to(
+                torch.int32) == group_expert_indices)
+
+            expert_vote_freq = vote_mask.sum(dim=0)
+
+            sorted_indices = torch.argsort(expert_vote_freq,
+                                           dim=1,
+                                           descending=True).to(torch.int32)
+            topk_experts = sorted_indices[:, :num_voted_experts]
+            keep_mask = ((
+                topk_experts.unsqueeze(-1) == expert_index_range).any(
+                    dim=1)).unsqueeze(0)
+
+            masked_scores = torch.where(keep_mask, scores_grouped, 0)
+
+            topk_weights, best_pos_in_group = masked_scores.max(dim=2)
+            best_pos_in_group = best_pos_in_group.to(torch.int32)
+            topk_ids = (best_pos_in_group + group_expert_offset).to(
+                torch.int32)
 
         flatten_topk_ids = topk_ids.view(-1)
         router_weights = router_weights.index_select(0, flatten_topk_ids).view(
```
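One detail worth calling out from the hunk above: the group-wise `torch.max` returns indices relative to each group, and the `bias` tensor shifts them back to absolute expert ids. The new code steps the `arange` by `experts_per_group` rather than `topk`, which stays correct even when the group size differs from `topk`. A small sketch of the trick with toy sizes:

```python
import torch

# Toy sizes: 16 local experts split into 8 groups of 2.
num_tokens, local_num_experts, experts_per_group = 4, 16, 2
local_num_group = local_num_experts // experts_per_group
scores = torch.rand(num_tokens, local_num_experts)

# Group-wise max gives indices *within* each group (0 or 1 here) ...
_, ids_in_group = torch.max(scores.view(num_tokens, local_num_group, -1), dim=-1)

# ... and adding each group's starting offset yields absolute expert ids.
bias = torch.arange(0, local_num_experts, experts_per_group,
                    dtype=torch.int32).unsqueeze(0)
absolute_ids = ids_in_group.to(torch.int32) + bias  # (num_tokens, local_num_group)
```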
```diff
@@ -382,6 +423,11 @@ class PanguProMoESparseMoeBlock(nn.Module):
 
         return topk_weights, topk_ids
 
+    return pangu_group8_topk
+
+
+class PanguProMoESparseMoeBlock(nn.Module):
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -397,14 +443,15 @@ class PanguProMoESparseMoeBlock(nn.Module):
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.local_num_group = config.num_experts_per_tok // get_ep_group(
-        ).world_size
+        self.num_experts_per_tok = config.num_experts_per_tok
         self.local_num_experts = config.num_experts // get_ep_group(
         ).world_size
         self.router_scale = torch.nn.Parameter(
             torch.ones((1, self.num_experts)))
 
+        # On the 300I Duo platform, we find that num_voted_experts set to 5
+        # achieves good performance without sacrificing too much accuracy. On
+        # other platforms, it is set to 8 to use the original pangu grouped topk.
+        num_voted_experts = 5 if is_310p() else 8
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
@@ -412,8 +459,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             quant_config=quant_config,
-            custom_routing_function=PanguProMoESparseMoeBlock.
-            pangu_group8_topk,
+            custom_routing_function=topk_wrapper(num_voted_experts),
             prefix=f"{prefix}.experts",
         )
```
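`topk_wrapper` replaces the former static method with a closure factory: the platform-dependent `num_voted_experts` is captured once at construction, while the returned function keeps the signature `FusedMoE` expects for `custom_routing_function`. A toy sketch of the pattern (not the real call site):

```python
from typing import Callable, Tuple
import torch

def make_routing_fn(num_voted_experts: int) -> Callable:
    """Toy closure factory mirroring topk_wrapper: the vote count is baked in
    once, and the returned function keeps the routing-function signature."""
    def routing_fn(hidden_states: torch.Tensor, gating_output: torch.Tensor,
                   topk: int, renormalize: bool = False
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
        k = min(topk, num_voted_experts)  # captured from the factory call
        return torch.topk(torch.softmax(gating_output, dim=-1), k, dim=-1)
    return routing_fn

routing_fn = make_routing_fn(5)  # 5 on 300I Duo, 8 elsewhere, per the diff
```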
```diff
@@ -21,7 +21,7 @@ import torch
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
-from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
+from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
                                        select_experts)
 from vllm_ascend.utils import is_310p
@@ -58,9 +58,9 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if is_310p():
+    if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None
-        return fused_experts_310p(
+        return fused_experts_moge(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
```
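Judging by the routing code above, `pangu_group8_topk` returns only `topk // ep_size` expert ids per token under expert parallelism, so the new `topk_ids.shape[1] < top_k` guard also routes the expert-parallel case to `fused_experts_moge`, not just 310P. A toy illustration of the condition (names local to the sketch):

```python
import torch

# Toy shapes: under expert parallelism the pangu router returns fewer
# expert ids per token than top_k, which now triggers the moge path.
num_tokens, top_k, ep_size = 16, 8, 4
local_num_group = top_k // ep_size   # ids kept per token on this rank
topk_ids = torch.zeros(num_tokens, local_num_group, dtype=torch.int32)
use_moge = topk_ids.shape[1] < top_k  # True here: 2 < 8
```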
```diff
@@ -39,7 +39,7 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, npu_stream_switch,
+                               get_fused_moe_state, is_310p, npu_stream_switch,
                                npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer(
     return final_hidden_states
 
 
-# Currently, fused_experts on 310p only supports PanguProMoE.
-def fused_experts_310p(
+def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -614,8 +613,11 @@
         group_list=group_list,
     )[0]
 
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
-        torch.float16)
+    if is_310p():
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+            torch.float16)
+    else:
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
     gate_up_out *= topk_scales
 
     w2 = w2.transpose(1, 2)
```
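The branch above keeps the float32 round-trip around `npu_swiglu` on 310P only; other platforms now call the fused op on the incoming dtype directly. For reference, here is a plain-torch sketch of what a SwiGLU activation computes, assuming the usual convention that the gate and up projections are concatenated along the last dimension (an assumption about the fused op, not something stated in the diff):

```python
import torch
import torch.nn.functional as F

def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumed layout: [gate | up] concatenated along the last dimension.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

out = swiglu_reference(torch.randn(4, 2 * 1024))  # -> shape (4, 1024)
```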
```diff
@@ -628,8 +630,7 @@
         group_list=group_list,
     )[0]
 
-    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
-        torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
     unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
     final_hidden_states = unsorted_hidden_states.reshape(
         bsz, top_k // ep_size, -1).sum(1)
```
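The dropped `+ torch.Tensor([0]).to(torch.int32).npu()` term added zero and was a no-op. What remains relies on a standard identity: `argsort` of a permutation is its inverse, which restores the expert-sorted rows to the original token order. A CPU-only sketch:

```python
import torch

perm = torch.tensor([2, 0, 3, 1])    # order the rows were gathered in
rows = torch.tensor([10., 20., 30., 40.])
sorted_rows = rows[perm]             # rows grouped for the expert matmuls
inverse = torch.argsort(perm.float()).to(torch.int32)
restored = sorted_rows.index_select(0, inverse)
assert torch.equal(restored, rows)   # original token order recovered
```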