[Bugfix] fix mtp profile run error where main model and mtp model use different quantization (#4102)

### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420, we
initially placed the quantization type (quant_type) in the MoECommMethod
class. However, since MoECommMethod follows a singleton pattern, it
couldn't accommodate scenarios where different layers in the model might
use different quantization approaches (e.g., MTP modules using
floating-point computation while the main model employs quantized
computation).
In this PR, we've moved the quantization type to the AscendFusedMoe
class and pass it as a parameter to MoECommMethod.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0

vllm serve /home/data/DeepSeek-R1_w8a8/ \
 --data-parallel-size 2 \
 --tensor-parallel-size 8 \
 --enable-expert-parallel \
 --served-model-name dsv3 \
 --max-model-len 32768 \
 --max-num-batched-tokens 4096 \
 --max-num-seqs 16 \
 --quantization ascend \
 --trust-remote-code \
 --gpu-memory-utilization 0.9 \
 --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-11-13 11:02:31 +08:00
committed by GitHub
parent 17259cb265
commit 5093192769
6 changed files with 82 additions and 76 deletions

View File

@@ -53,11 +53,8 @@ class PrepareAndFinalize(ABC):
sizes, ranks, and communication settings.
"""
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
self.quant_type = quant_type
@abstractmethod
def prepare(
@@ -65,7 +62,8 @@ class PrepareAndFinalize(ABC):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@@ -79,6 +77,7 @@ class PrepareAndFinalize(ABC):
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
enable_shared_expert_dp (bool): Skip DP communication for shared experts
replace_allreduce (bool): Bypass default all-reduce behavior
quant_type: none, w8a8 or w4a8
Returns:
Tuple of:
@@ -117,10 +116,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
super().__init__(moe_config, quant_type)
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
@@ -133,7 +130,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@@ -211,10 +209,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
super().__init__(moe_config, quant_type)
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
@@ -231,7 +227,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@@ -312,6 +309,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@@ -322,7 +320,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits)
return self._prepare_with_ep_group(hidden_states, router_logits,
quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
@@ -332,10 +331,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
pertoken_scale = None
if self.quant_type == QuantType.W8A8:
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -356,6 +356,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@@ -500,7 +501,8 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""