[Bugfix] fix mtp profile run error where main model and mtp model use different quantization (#4102)
### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420, we
initially placed the quantization type (quant_type) in the MoECommMethod
class. However, since MoECommMethod follows a singleton pattern, it
couldn't accommodate scenarios where different layers in the model might
use different quantization approaches (e.g., MTP modules using
floating-point computation while the main model employs quantized
computation).
In this PR, we've moved the quantization type to the AscendFusedMoe
class and pass it as a parameter to MoECommMethod.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0
vllm serve /home/data/DeepSeek-R1_w8a8/ \
--data-parallel-size 2 \
--tensor-parallel-size 8 \
--enable-expert-parallel \
--served-model-name dsv3 \
--max-model-len 32768 \
--max-num-batched-tokens 4096 \
--max-num-seqs 16 \
--quantization ascend \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -53,11 +53,8 @@ class PrepareAndFinalize(ABC):
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
self.quant_type = quant_type
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
@@ -65,7 +62,8 @@ class PrepareAndFinalize(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -79,6 +77,7 @@ class PrepareAndFinalize(ABC):
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
quant_type: none, w8a8 or w4a8
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
@@ -117,10 +116,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
super().__init__(moe_config, quant_type)
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
@@ -133,7 +130,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -211,10 +209,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
super().__init__(moe_config, quant_type)
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
@@ -231,7 +227,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -312,6 +309,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -322,7 +320,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits)
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits,
|
||||
quant_type)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
@@ -332,10 +331,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
pertoken_scale = None
|
||||
if self.quant_type == QuantType.W8A8:
|
||||
if quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
@@ -356,6 +356,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -500,7 +501,8 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user