diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py
index 131e1e0..644a00e 100644
--- a/vllm_ascend/models/pangu_moe.py
+++ b/vllm_ascend/models/pangu_moe.py
@@ -57,6 +57,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import is_310p
 
 logger = init_logger(__name__)
 
@@ -339,41 +340,81 @@ class PanguProMoEMLP(nn.Module):
         return x
 
 
-class PanguProMoESparseMoeBlock(nn.Module):
+def topk_wrapper(num_voted_experts):
 
-    @staticmethod
     def pangu_group8_topk(
         hidden_states: torch.Tensor,
         gating_output: torch.Tensor,
         topk: int,
-        renormalize: bool,
+        renormalize: bool = False,
         num_expert_group: int = 0,
         topk_group: int = 0,
         global_num_experts: int = 0,
     ):
+        scores = F.softmax(gating_output, dim=1)
+        num_tokens = scores.shape[0]
+        router_scale = _ROUTER_SCALE.squeeze(  # type: ignore
+        )
+
         ep_size = get_ep_group().world_size
         local_num_experts = global_num_experts // ep_size
         local_num_group = topk // ep_size
-        router_scale = _ROUTER_SCALE.squeeze()  # type: ignore
+        experts_per_group = global_num_experts // topk
+        local_group_start = get_ep_group().rank_in_group * local_num_experts
+        local_group_end = (get_ep_group().rank_in_group +
+                           1) * local_num_experts
 
         scores = F.softmax(gating_output, dim=1)
-        scores = scores[...,
-                        get_ep_group().rank_in_group *
-                        local_num_experts:(get_ep_group().rank_in_group + 1) *
-                        local_num_experts]
+        scores = scores[..., local_group_start:local_group_end]
 
-        router_weights = router_scale[get_ep_group().rank_in_group *
-                                      local_num_experts:
-                                      (get_ep_group().rank_in_group + 1) *
-                                      local_num_experts]
-        topk_weights, topk_ids = torch.max(scores.view(scores.shape[0],
-                                                       local_num_group, -1),
-                                           dim=-1)
-        bias = torch.arange(0,
-                            local_num_experts,
-                            topk,
-                            device=scores.device,
-                            dtype=torch.int32).unsqueeze(0)
-        topk_ids = topk_ids.to(torch.int32) + bias
+        router_weights = router_scale[local_group_start:local_group_end]
+
+        if num_voted_experts == 8:
+            # use original topk
+            topk_weights, topk_ids = torch.max(scores.view(
+                scores.shape[0], local_num_group, -1),
+                                               dim=-1)
+            bias = torch.arange(0,
+                                local_num_experts,
+                                experts_per_group,
+                                device=scores.device,
+                                dtype=torch.int32).unsqueeze(0)
+            topk_ids = topk_ids.to(torch.int32) + bias
+
+        else:
+            group_expert_indices = torch.arange(experts_per_group,
+                                                dtype=torch.int32,
+                                                device=scores.device).view(
+                                                    1, 1, -1)
+            group_expert_offset = (torch.arange(
+                local_num_group, dtype=torch.int32, device=scores.device) *
+                                   experts_per_group).unsqueeze(0)
+            expert_index_range = torch.arange(experts_per_group,
+                                              dtype=torch.int32,
+                                              device=scores.device)
+
+            scores_grouped = scores.view(num_tokens, local_num_group,
+                                         experts_per_group)
+            best_expert_idx = torch.argmax(scores_grouped,
+                                           dim=2)  # (num_tokens, num_groups)
+            vote_mask = (best_expert_idx.unsqueeze(-1).to(
+                torch.int32) == group_expert_indices)
+
+            expert_vote_freq = vote_mask.sum(dim=0)
+
+            sorted_indices = torch.argsort(expert_vote_freq,
+                                           dim=1,
+                                           descending=True).to(torch.int32)
+            topk_experts = sorted_indices[:, :num_voted_experts]
+            keep_mask = ((
+                topk_experts.unsqueeze(-1) == expert_index_range).any(
+                    dim=1)).unsqueeze(0)
+
+            masked_scores = torch.where(keep_mask, scores_grouped, 0)
+
+            topk_weights, best_pos_in_group = masked_scores.max(dim=2)
+            best_pos_in_group = best_pos_in_group.to(torch.int32)
+            topk_ids = (best_pos_in_group + group_expert_offset).to(
+                torch.int32)
 
         flatten_topk_ids = topk_ids.view(-1)
         router_weights = router_weights.index_select(0, flatten_topk_ids).view(
@@ -382,6 +423,11 @@ class PanguProMoESparseMoeBlock(nn.Module):
 
         return topk_weights, topk_ids
 
+    return pangu_group8_topk
+
+
+class PanguProMoESparseMoeBlock(nn.Module):
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -397,14 +443,15 @@ class PanguProMoESparseMoeBlock(nn.Module):
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.local_num_group = config.num_experts_per_tok // get_ep_group(
-        ).world_size
         self.num_experts_per_tok = config.num_experts_per_tok
-        self.local_num_experts = config.num_experts // get_ep_group(
-        ).world_size
         self.router_scale = torch.nn.Parameter(
             torch.ones((1, self.num_experts)))
 
+        # On the 300I Duo platform, we find num_voted_experts = 5 achieves
+        # good performance without sacrificing too much accuracy; on other
+        # platforms it is set to 8 to use the original pangu grouped topk.
+        num_voted_experts = 5 if is_310p() else 8
+
         self.experts = FusedMoE(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
@@ -412,8 +459,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             quant_config=quant_config,
-            custom_routing_function=PanguProMoESparseMoeBlock.
-            pangu_group8_topk,
+            custom_routing_function=topk_wrapper(num_voted_experts),
             prefix=f"{prefix}.experts",
         )
 
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 3c84f23..4e21c74 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -21,7 +21,7 @@ import torch
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
-from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
+from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
                                        select_experts)
 from vllm_ascend.utils import is_310p
 
@@ -58,9 +58,9 @@ def forward_oot(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if is_310p():
+    if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None
-        return fused_experts_310p(
+        return fused_experts_moge(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index cad90a4..bea0dc5 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -39,7 +39,7 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, npu_stream_switch,
+                               get_fused_moe_state, is_310p, npu_stream_switch,
                                npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -548,8 +548,7 @@ def fused_experts_with_all2all_buffer(
     return final_hidden_states
 
 
-# Currently, fused_experts on 310p only supports PanguProMoE.
-def fused_experts_310p(
+def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -614,8 +613,11 @@ def fused_experts_310p(
         group_list=group_list,
     )[0]
 
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
-        torch.float16)
+    if is_310p():
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+            torch.float16)
+    else:
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
     gate_up_out *= topk_scales
 
     w2 = w2.transpose(1, 2)
@@ -628,8 +630,7 @@ def fused_experts_310p(
         group_list=group_list,
     )[0]
 
-    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
-        torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
     unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
     final_hidden_states = unsorted_hidden_states.reshape(
         bsz, top_k // ep_size, -1).sum(1)
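
Reviewer note: the sketch below is a minimal, CPU-only illustration of the voting branch added to `pangu_group8_topk` (the `num_voted_experts < 8` path), assuming a single EP rank and toy sizes. `voted_group_topk` is a name invented for this sketch, not part of the patch, and it is not the production kernel path.

```python
# Minimal sketch of the voting-based grouped top-k, assuming ep_size == 1.
import torch
import torch.nn.functional as F

def voted_group_topk(gating_output: torch.Tensor, topk: int,
                     global_num_experts: int, num_voted_experts: int):
    scores = F.softmax(gating_output, dim=1)
    num_tokens = scores.shape[0]
    num_groups = topk                      # one expert is chosen per group
    experts_per_group = global_num_experts // topk

    grouped = scores.view(num_tokens, num_groups, experts_per_group)
    # Each token "votes" for its best expert within every group.
    best = grouped.argmax(dim=2)                           # (tokens, groups)
    votes = F.one_hot(best, experts_per_group).sum(dim=0)  # (groups, experts_per_group)
    # Keep only the num_voted_experts most-voted experts in each group.
    top_voted = votes.argsort(dim=1, descending=True)[:, :num_voted_experts]
    keep = torch.zeros(num_groups, experts_per_group, dtype=torch.bool)
    keep.scatter_(1, top_voted, True)
    # Re-run the per-group max over the surviving experts only.
    masked = torch.where(keep.unsqueeze(0), grouped, torch.zeros(()))
    topk_weights, pos_in_group = masked.max(dim=2)
    offsets = torch.arange(num_groups) * experts_per_group
    topk_ids = pos_in_group + offsets.unsqueeze(0)         # global expert ids
    return topk_weights, topk_ids

# 4 tokens, 64 experts in 8 groups of 8; vote 5 experts/group as on 300I Duo.
w, ids = voted_group_topk(torch.randn(4, 64), topk=8,
                          global_num_experts=64, num_voted_experts=5)
print(ids.shape)  # torch.Size([4, 8]): one expert id per group
```

The effect is that rarely-chosen experts are pruned batch-wide before the per-group argmax, which shrinks the set of experts the grouped GEMM has to touch on 310P.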
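
The new dispatch check in `forward_oot` leans on a shape property: under expert parallelism, pangu's grouped top-k returns only `top_k // ep_size` ids per token, so `topk_ids.shape[1] < top_k` identifies this routing scheme on any SoC, while `is_310p()` keeps the old platform trigger. A toy sketch of the rule (`pick_moe_kernel` is hypothetical, for illustration only):

```python
# Hypothetical restatement of the dispatch rule in forward_oot.
import torch

def pick_moe_kernel(topk_ids: torch.Tensor, top_k: int, on_310p: bool) -> str:
    # A narrower topk_ids tensor signals pangu's EP-sharded grouped top-k.
    if topk_ids.shape[1] < top_k or on_310p:
        return "fused_experts_moge"
    return "fused_experts"

topk_ids = torch.zeros(16, 4, dtype=torch.int32)  # 8 experts/token, ep_size=2
print(pick_moe_kernel(topk_ids, top_k=8, on_310p=False))  # fused_experts_moge
```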
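
The swiglu hunk keeps the float32-compute/float16-output path only on 310P and lets other SoCs run the activation in the tensor's native dtype. A CPU stand-in sketch, assuming `torch_npu.npu_swiglu` splits the last dimension in half like the reference SwiGLU (the `swiglu` helper below is a stand-in, not the NPU kernel):

```python
# CPU stand-in for the dtype guard around torch_npu.npu_swiglu.
import torch
import torch.nn.functional as F

def swiglu(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)  # assumed split, matching reference SwiGLU
    return F.silu(gate) * up

def apply_swiglu(gate_up_out: torch.Tensor, on_310p: bool) -> torch.Tensor:
    if on_310p:
        # 310P: upcast for numerical stability, emit fp16 for later matmuls.
        return swiglu(gate_up_out.to(torch.float32)).to(torch.float16)
    return swiglu(gate_up_out)  # other SoCs: keep native dtype (e.g. bf16)

y = apply_swiglu(torch.randn(4, 8, dtype=torch.float16), on_310p=True)
print(y.dtype)  # torch.float16
```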
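
The last hunk drops the `+ torch.Tensor([0]).to(torch.int32).npu()` term, which only added zero (presumably a leftover workaround), so removing it does not change the result. The remaining `argsort` works because argsort of a permutation yields its inverse, which restores token order after the expert-grouped GEMMs. A small self-contained demonstration:

```python
# argsort of a permutation is its inverse: rows grouped by expert go back
# to their original (token) order.
import torch

sorted_topk_ids = torch.tensor([2, 0, 3, 1])      # row i came from slot ids[i]
rows = torch.tensor([[20.], [0.], [30.], [10.]])  # expert outputs, sorted order
inverse = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
print(rows.index_select(0, inverse).squeeze(1))   # tensor([ 0., 10., 20., 30.])
```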