[main] [bugfix] Fix misjudging quantized/unquantized scenarios (#2627)
### What this PR does / why we need it?
In a mixed-precision scenario, quant_config is not None, but MoE needs
to perform unquantized computation; however, quantized computation is
currently being used. Therefore, we put the with_quant logic into
forward, avoid misjudging in mix-precision scenarios.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
98ac0cb32d
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -490,7 +490,6 @@ class MoETokenDispatcher(ABC):
|
||||
"""
|
||||
self.top_k = kwargs.get("top_k", 0)
|
||||
self.num_experts = kwargs.get("num_experts", 0)
|
||||
self.with_quant = kwargs.get("with_quant", False)
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
@@ -518,7 +517,8 @@ class MoETokenDispatcher(ABC):
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
@@ -555,6 +555,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.topk_weights = None
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
@@ -615,7 +616,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.expert_map = expert_map
|
||||
self.topk_ids = topk_ids
|
||||
self.topk_weights = topk_weights
|
||||
@@ -738,6 +741,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
self.expert_map = None
|
||||
self.topk_weights = None
|
||||
self.topk_ids = None
|
||||
self.with_quant = False
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -751,7 +755,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
@@ -922,7 +928,8 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.with_quant = False
|
||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||
self.num_global_redundant_experts = kwargs.get(
|
||||
"num_global_redundant_experts", 0)
|
||||
@@ -1032,7 +1040,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
shared_gate_up: Optional[torch.Tensor] = None,
|
||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.topk_weights = topk_weights
|
||||
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
||||
|
||||
Reference in New Issue
Block a user