[CI] Fix CI Break: upstream adds routed_scaling_factor in forward_oot interface (#2675)

### What this PR does / why we need it?
Fix CI Break: upstream adds routed_scaling_factor in forward_oot
interface, vllm-ascend needs to adapt

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
E2E and UT

- vLLM version: v0.10.1.1
- vLLM main:
3e330fcb21

Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
leo-pony
2025-09-01 19:02:50 +08:00
committed by GitHub
parent ea53f9076e
commit 0df059f41a
2 changed files with 73 additions and 3 deletions

View File

@@ -35,7 +35,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -246,7 +246,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
and not vllm_config.model_config.enforce_eager)
def forward_oot(
def forward_oot_v01011(
self,
layer: torch.nn.Module,
x: torch.Tensor,
@@ -278,6 +278,69 @@ def forward_oot(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=1.0,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
@@ -441,4 +504,8 @@ class AscendFusedMoE(FusedMoE):
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
else:
UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -40,6 +40,7 @@ def select_experts(hidden_states: torch.Tensor,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
is_unquantized: bool = False,
@@ -78,6 +79,7 @@ def select_experts(hidden_states: torch.Tensor,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts,
is_unquantized=is_unquantized)
@@ -180,6 +182,7 @@ def _select_experts_with_fusion_ops(
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1,
is_unquantized: bool = False):