diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 7265113..2c11e6d 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -35,7 +35,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ setup_token_dispatchers -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ @@ -246,7 +246,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs): and not vllm_config.model_config.enforce_eager) -def forward_oot( +def forward_oot_v01011( self, layer: torch.nn.Module, x: torch.Tensor, @@ -278,6 +278,69 @@ def forward_oot( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=1.0, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + if topk_ids.shape[1] < top_k or is_310p(): + assert global_num_experts is not None + return fused_experts_moge( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + moe_parallel_config=self.moe.moe_parallel_config, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + +def forward_oot( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: + + topk_weights, topk_ids, _ = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) @@ -441,4 +504,8 @@ class AscendFusedMoE(FusedMoE): UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading -UnquantizedFusedMoEMethod.forward_oot = forward_oot + +if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"): + UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011 +else: + UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/layers/experts_selector.py index 11524ac..c1f9312 100644 --- a/vllm_ascend/ops/layers/experts_selector.py +++ b/vllm_ascend/ops/layers/experts_selector.py @@ -40,6 +40,7 @@ def select_experts(hidden_states: torch.Tensor, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor=1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, is_unquantized: bool = False, @@ -78,6 +79,7 @@ def select_experts(hidden_states: torch.Tensor, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, global_num_experts=global_num_experts, is_unquantized=is_unquantized) @@ -180,6 +182,7 @@ def _select_experts_with_fusion_ops( num_expert_group: Optional[int], custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor=1.0, global_num_experts: int = -1, is_unquantized: bool = False):