[main] [bugfix] Fix misjudging quantized/unquantized scenarios (#2627)

### What this PR does / why we need it?
In a mixed-precision scenario, quant_config is not None, but MoE needs
to perform unquantized computation; however, quantized computation is
currently being used. Therefore, we move the with_quant logic into
forward to avoid misjudgment in mixed-precision scenarios.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut

- vLLM version: v0.10.1.1
- vLLM main:
98ac0cb32d

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-08-29 16:20:22 +08:00
committed by GitHub
parent aadc75c247
commit 52aff9e229
7 changed files with 62 additions and 65 deletions

View File

@@ -408,19 +408,19 @@ def unquant_apply_mlp(
return hidden_states
def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
if get_forward_context().with_quant:
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
@@ -457,7 +457,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
token_dispatcher = get_forward_context().token_dispatcher
results = token_dispatcher.token_dispatch(
@@ -472,7 +473,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=with_quant)
expert_output = unified_apply_mlp(
hidden_states=results["hidden_states"],
@@ -485,7 +487,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
group_list_type=results.get("group_list_type"),
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"))
topk_scales=results.get("topk_scales"),
with_quant=with_quant)
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states
@@ -577,7 +580,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map=expert_map,
shared_experts=shared_experts,
mc2_mask=kwargs.get(
"mc2_mask", None))
"mc2_mask", None),
with_quant=False)
class AscendFusedMoE(FusedMoE):
@@ -761,7 +765,6 @@ class AscendFusedMoE(FusedMoE):
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
with_quant = quant_config is not None
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
@@ -769,8 +772,7 @@ class AscendFusedMoE(FusedMoE):
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts,
with_quant=with_quant)
num_local_experts=self.local_num_experts)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):

View File

@@ -490,7 +490,6 @@ class MoETokenDispatcher(ABC):
"""
self.top_k = kwargs.get("top_k", 0)
self.num_experts = kwargs.get("num_experts", 0)
self.with_quant = kwargs.get("with_quant", False)
@property
def ep_group(self):
@@ -518,7 +517,8 @@ class MoETokenDispatcher(ABC):
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
@@ -555,6 +555,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False
def get_dispatch_mc2_kwargs(
self,
@@ -615,7 +616,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
@@ -738,6 +741,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False
def token_dispatch(self,
hidden_states: torch.Tensor,
@@ -751,7 +755,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
@@ -922,7 +928,8 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
@@ -1032,7 +1040,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"