[Bugfix] Fix MTP profile-run error when the main model and the MTP model use different quantization (#4102)
### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420 we initially
placed the quantization type (`quant_type`) in the `MoECommMethod` class.
However, since `MoECommMethod` follows a singleton pattern, it could not
accommodate models whose layers use different quantization approaches,
e.g. MTP modules computing in floating point while the main model runs
quantized computation.
In this PR, we move the quantization type into the `AscendFusedMoE`
class and pass it to `MoECommMethod` as a per-call parameter.
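
To make the failure concrete: the comm methods are constructed once and shared by every MoE layer, so any quantization state cached on them leaks from the main model into the MTP module. Below is a minimal sketch of the fixed calling convention (simplified, hypothetical names; the real classes are `MoECommMethod` and `AscendFusedMoE`):

```python
from enum import Enum


class QuantType(Enum):
    NONE = "none"
    W8A8 = "w8a8"
    W4A8 = "w4a8"


class CommMethod:
    """One shared instance serves every MoE layer (singleton per comm type),
    so it must not cache per-layer state such as the quantization type."""

    def prepare(self, hidden_states, quant_type: QuantType = QuantType.NONE):
        # quant_type now arrives per call instead of being fixed at
        # construction time, so each layer can choose its own path.
        if quant_type in (QuantType.W8A8, QuantType.W4A8):
            return f"quantized({hidden_states})"
        return hidden_states  # MTP module: plain floating point


comm = CommMethod()  # one instance shared by all layers
print(comm.prepare("main_model_tokens", QuantType.W8A8))  # quantized path
print(comm.prepare("mtp_tokens"))                         # float path
```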
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0
vllm serve /home/data/DeepSeek-R1_w8a8/ \
--data-parallel-size 2 \
--tensor-parallel-size 8 \
--enable-expert-parallel \
--served-model-name dsv3 \
--max-model-len 32768 \
--max-num-batched-tokens 4096 \
--max-num-seqs 16 \
--quantization ascend \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Unit tests (`TestMoECommMethod`):

```diff
@@ -7,6 +7,7 @@ from tests.ut.base import TestBase
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        AlltoAllCommImpl,
                                                        MC2CommImpl)
+from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 
 
 class TestMoECommMethod(TestBase):
@@ -67,7 +68,7 @@ class TestMoECommMethod(TestBase):
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)
 
         # Test finalize method
         comm_impl.finalize(h_out,
@@ -115,7 +116,7 @@ class TestMoECommMethod(TestBase):
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)
 
         # Test finalize method
         comm_impl.finalize(h_out,
@@ -165,7 +166,7 @@ class TestMoECommMethod(TestBase):
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)
 
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
```
`AscendFusedMoE` (fused MoE layer):

```diff
@@ -37,6 +37,11 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
+from vllm_ascend.quantization.w4a8_dynamic import \
+    AscendW4A8DynamicFusedMoEMethod
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
                                shared_expert_dp_enabled,
@@ -289,7 +294,23 @@ class AscendFusedMoE(FusedMoE):
 
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
-        setup_moe_comm_method(self.moe_config, self.quant_method)
+        setup_moe_comm_method(self.moe_config)
+        self.quant_type = self._get_quant_type()
+
+    def _get_quant_type(self) -> QuantType:
+        quant_method = self.quant_method
+        if not hasattr(quant_method,
+                       "quant_method") or quant_method.quant_method is None:
+            return QuantType.NONE
+
+        method = quant_method.quant_method
+
+        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
+            return QuantType.W8A8
+        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
+            return QuantType.W4A8
+        else:
+            return QuantType.NONE
 
     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -334,7 +355,8 @@ class AscendFusedMoE(FusedMoE):
             hidden_states=hidden_states,
             router_logits=router_logits,
             replace_allreduce=forward_context.sp_enabled,
-            enable_shared_expert_dp=self.enable_shared_expert_dp)
+            enable_shared_expert_dp=self.enable_shared_expert_dp,
+            quant_type=self.quant_type)
 
         if isinstance(hidden_states, tuple):
             hidden_states, pertoken_scale = hidden_states
```
`MoECommMethod` and the comm-method registry:

```diff
@@ -31,10 +31,6 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
-from vllm_ascend.quantization.w4a8_dynamic import \
-    AscendW4A8DynamicFusedMoEMethod
-from vllm_ascend.quantization.w8a8_dynamic import \
-    AscendW8A8DynamicFusedMoEMethod
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -44,54 +40,37 @@ def get_moe_comm_method(
     return _MoECommMethods.get(moe_comm_type, None)
 
 
-def setup_moe_comm_method(moe_config, quant_method):
-    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
-        moe_config, quant_method)
-    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
-        moe_config, quant_method)
-    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
+def setup_moe_comm_method(moe_config):
+    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
+    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
+    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
     _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config, quant_method)
+        moe_config)
 
 
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
-    def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
+    def __init__(self, moe_config: FusedMoEConfig):
         self.model_type = get_current_vllm_config(
         ).model_config.hf_config.model_type
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
-        self.quant_type = self._get_quant_type(quant_method)
-        self.with_quant = self.quant_type != QuantType.NONE
         self.prepare_finalize = self._get_prepare_finalize()
 
-    def _get_quant_type(self, quant_method) -> QuantType:
-        if not hasattr(quant_method,
-                       "quant_method") or quant_method.quant_method is None:
-            return QuantType.NONE
-
-        method = quant_method.quant_method
-
-        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
-            return QuantType.W8A8
-        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
-            return QuantType.W4A8
-        else:
-            return QuantType.NONE
-
     def prepare(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type: QuantType = QuantType.NONE,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
-            replace_allreduce)
+            replace_allreduce, quant_type)
         return hidden_states, router_logits, mc2_mask, context_metadata
 
     def finalize(self,
@@ -112,6 +91,8 @@ class MoECommMethod(ABC):
                  topk_ids: torch.Tensor,
                  activation: str = "silu",
                  apply_router_weight_on_input: bool = False,
+                 use_int8_w8a8: bool = False,
+                 use_int4_w4a8: bool = False,
                  global_num_experts: Optional[int] = None,
                  expert_map: Optional[torch.Tensor] = None,
                  w1_scale: Optional[torch.Tensor] = None,
@@ -151,15 +132,14 @@ class MoECommMethod(ABC):
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=self.with_quant,
+            with_quant=use_int8_w8a8 or use_int4_w4a8,
             dynamic_eplb=dynamic_eplb,
             pertoken_scale=pertoken_scale)
 
         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
 
-        mlp_output = unified_apply_mlp(
-            hidden_states=permuted_hidden_states,
+        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
                                        w1=w1,
                                        w1_scale=w1_scale,
                                        w2=w2,
@@ -170,8 +150,9 @@ class MoECommMethod(ABC):
                                        w1_scale_bias=w1_scale_bias,
                                        w2_scale_bias=w2_scale_bias,
                                        topk_scales=topk_scales,
-                                       with_quant=self.with_quant,
-                                       fusion=self.quant_type == QuantType.W8A8,
+                                       with_quant=use_int8_w8a8
+                                       or use_int4_w4a8,
+                                       fusion=use_int8_w8a8,
                                        need_trans=need_trans,
                                        dynamic_eplb=dynamic_eplb)
 
@@ -226,8 +207,7 @@ class AllGatherCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAllGather(self.moe_config,
-                                               self.quant_type)
+        return PrepareAndFinalizeWithAllGather(self.moe_config)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -244,7 +224,7 @@ class MC2CommImpl(MoECommMethod):
         return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
+        return PrepareAndFinalizeWithMC2(self.moe_config)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -264,7 +244,7 @@ class AlltoAllCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
+        return PrepareAndFinalizeWithAll2All(self.moe_config)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -293,5 +273,4 @@ class NaiveMulticastCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
-                                                    self.quant_type)
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
```
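Alongside the `quant_type` plumbing, `fused_experts` stops consulting stored state and derives its quantization flags from the new `use_int8_w8a8` / `use_int4_w4a8` arguments. A standalone sketch of that derivation (hypothetical helper, mirroring the expressions in the diff):

```python
from typing import NamedTuple


class MlpFlags(NamedTuple):
    with_quant: bool  # quantized MLP path is active
    fusion: bool      # fused kernel path, taken only for W8A8


def resolve_mlp_flags(use_int8_w8a8: bool = False,
                      use_int4_w4a8: bool = False) -> MlpFlags:
    """Mirrors MoECommMethod.fused_experts after this PR:
    with_quant=use_int8_w8a8 or use_int4_w4a8, fusion=use_int8_w8a8."""
    return MlpFlags(with_quant=use_int8_w8a8 or use_int4_w4a8,
                    fusion=use_int8_w8a8)


# The W8A8 method passes use_int8_w8a8=True, the W4A8 method passes
# use_int4_w4a8=True, and unquantized layers (e.g. MTP) pass neither.
assert resolve_mlp_flags(use_int8_w8a8=True) == MlpFlags(True, True)
assert resolve_mlp_flags(use_int4_w4a8=True) == MlpFlags(True, False)
assert resolve_mlp_flags() == MlpFlags(False, False)
```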
`PrepareAndFinalize` and its subclasses:

```diff
@@ -53,11 +53,8 @@ class PrepareAndFinalize(ABC):
     sizes, ranks, and communication settings.
     """
 
-    def __init__(self,
-                 moe_config: FusedMoEConfig,
-                 quant_type: QuantType = QuantType.NONE):
+    def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
-        self.quant_type = quant_type
 
     @abstractmethod
     def prepare(
@@ -65,7 +62,8 @@ class PrepareAndFinalize(ABC):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type: QuantType = QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
@@ -79,6 +77,7 @@ class PrepareAndFinalize(ABC):
            router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
            enable_shared_expert_dp (bool): Skip DP communication for shared experts
            replace_allreduce (bool): Bypass default all-reduce behavior
+           quant_type: none, w8a8 or w4a8
 
         Returns:
             Tuple of:
@@ -117,10 +116,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
     """
 
-    def __init__(self,
-                 moe_config: FusedMoEConfig,
-                 quant_type: QuantType = QuantType.NONE):
-        super().__init__(moe_config, quant_type)
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -133,7 +130,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
@@ -211,10 +209,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
     """
 
-    def __init__(self,
-                 moe_config: FusedMoEConfig,
-                 quant_type: QuantType = QuantType.NONE):
-        super().__init__(moe_config, quant_type)
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -231,7 +227,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
@@ -312,6 +309,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
@@ -322,7 +320,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
             Tuple of (global_hidden_states, global_router_logits, None)
         """
         if enable_sp():
-            return self._prepare_with_ep_group(hidden_states, router_logits)
+            return self._prepare_with_ep_group(hidden_states, router_logits,
+                                               quant_type)
 
         return self._prepare_with_dp_group(hidden_states, router_logits,
                                            enable_shared_expert_dp,
@@ -332,10 +331,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         pertoken_scale = None
-        if self.quant_type == QuantType.W8A8:
+        if quant_type == QuantType.W8A8:
             hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
                 hidden_states)
             pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -356,6 +356,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
@@ -500,7 +501,8 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
         """
```
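In the AllGather path above, `torch_npu.npu_dynamic_quant` performs per-token dynamic INT8 quantization, returning the quantized tensor and a per-token scale. A rough CPU-side approximation for readers without an NPU (a sketch of the idea, not the kernel's exact semantics):

```python
import torch


def dynamic_quant(hidden_states: torch.Tensor):
    """Per-token symmetric INT8 quantization: one scale per token (row),
    chosen so the row's max |value| maps to 127."""
    scale = hidden_states.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)  # guard against all-zero rows
    quantized = torch.round(hidden_states / scale).to(torch.int8)
    return quantized, scale.squeeze(-1)


x = torch.randn(4, 8)
q, pertoken_scale = dynamic_quant(x)
print(q.dtype, pertoken_scale.shape)  # torch.int8 torch.Size([4])
```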
`AscendW4A8DynamicFusedMoEMethod`:

```diff
@@ -386,6 +386,7 @@ class AscendW4A8DynamicFusedMoEMethod
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
```
`AscendW8A8DynamicFusedMoEMethod`:

```diff
@@ -256,6 +256,7 @@ class AscendW8A8DynamicFusedMoEMethod
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
```