[Perf] move quant before allgather in Allgather EP (#3420)
### What this PR does / why we need it?
Move quantization before the allgather in the Allgather EP path, so the collective carries quantized tokens (plus per-token scales) instead of full-width activations; a short sketch of the idea follows the table below. Relies on https://github.com/vllm-project/vllm-ascend/pull/3334.

DeepSeek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|----------|----------|----------|
| 4k | 375.21 | 364.99 |
| 16k | 1465.23 | 1421.75 |
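
The gain comes from shrinking the collective payload: with W8A8, each rank can dynamically quantize its local tokens to int8 before the allgather, so the collective moves one byte per element plus a small per-token fp32 scale instead of two bytes of bf16. A minimal, self-contained sketch of that idea in plain PyTorch (illustrative helper names, not the vllm-ascend kernels):

```python
import torch

def pertoken_quant_int8(x: torch.Tensor):
    """Per-token dynamic int8 quantization: one fp32 scale per token (row)."""
    # Map each row's max magnitude to 127; the clamp avoids dividing by zero.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale.to(torch.float32)

# Baseline order: allgather bf16 tokens, then quantize on every rank.
# This PR's order: quantize first, then allgather int8 tokens + scales,
# roughly halving the bytes on the wire for bf16 inputs.
x = torch.randn(16, 7168, dtype=torch.bfloat16)  # local tokens on one rank
q, pertoken_scale = pertoken_quant_int8(x.float())
bytes_bf16 = x.numel() * x.element_size()            # 2 bytes per element
bytes_int8 = q.numel() + pertoken_scale.numel() * 4  # 1 byte + fp32 scales
print(f"allgather payload: {bytes_bf16} B (bf16) vs {bytes_int8} B (int8 + scales)")
```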
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
```diff
@@ -27,10 +27,14 @@ from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
-    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+from vllm_ascend.quantization.w4a8_dynamic import \
+    AscendW4A8DynamicFusedMoEMethod
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -40,25 +44,43 @@ def get_moe_comm_method(
     return _MoECommMethods.get(moe_comm_type, None)
 
 
-def setup_moe_comm_method(moe_config):
-    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
-    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
-    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
+def setup_moe_comm_method(moe_config, quant_method):
+    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
     _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config)
+        moe_config, quant_method)
 
 
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
-    def __init__(self, moe_config: FusedMoEConfig):
+    def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
         self.model_type = get_current_vllm_config(
         ).model_config.hf_config.model_type
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
+        self.quant_type = self._get_quant_type(quant_method)
+        self.with_quant = self.quant_type != QuantType.NONE
         self.prepare_finalize = self._get_prepare_finalize()
 
+    def _get_quant_type(self, quant_method) -> QuantType:
+        if not hasattr(quant_method,
+                       "quant_method") or quant_method.quant_method is None:
+            return QuantType.NONE
+
+        method = quant_method.quant_method
+
+        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
+            return QuantType.W8A8
+        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
+            return QuantType.W4A8
+        else:
+            return QuantType.NONE
+
     def prepare(
         self,
         hidden_states: torch.Tensor,
@@ -90,8 +112,6 @@ class MoECommMethod(ABC):
         topk_ids: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
@@ -109,10 +129,11 @@ class MoECommMethod(ABC):
         global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
-        mc2_mask: torch.Tensor = None):
+        mc2_mask: torch.Tensor = None,
+        pertoken_scale: Optional[torch.Tensor] = None):
         # Check constraints
         assert hidden_states.dtype in [
-            torch.float32, torch.float16, torch.bfloat16
+            torch.float32, torch.float16, torch.bfloat16, torch.int8
         ]
 
         moe_comm_method = get_forward_context().moe_comm_method
@@ -130,28 +151,29 @@ class MoECommMethod(ABC):
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=use_int8_w8a8 or use_int4_w4a8,
-            dynamic_eplb=dynamic_eplb)
+            with_quant=self.with_quant,
+            dynamic_eplb=dynamic_eplb,
+            pertoken_scale=pertoken_scale)
 
         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
 
-        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
-                                       w1=w1,
-                                       w1_scale=w1_scale,
-                                       w2=w2,
-                                       w2_scale=w2_scale,
-                                       group_list=expert_tokens,
-                                       dynamic_scale=dynamic_scale,
-                                       group_list_type=group_list_type,
-                                       w1_scale_bias=w1_scale_bias,
-                                       w2_scale_bias=w2_scale_bias,
-                                       topk_scales=topk_scales,
-                                       with_quant=use_int8_w8a8
-                                       or use_int4_w4a8,
-                                       fusion=use_int8_w8a8,
-                                       need_trans=need_trans,
-                                       dynamic_eplb=dynamic_eplb)
+        mlp_output = unified_apply_mlp(
+            hidden_states=permuted_hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            group_list=expert_tokens,
+            dynamic_scale=dynamic_scale,
+            group_list_type=group_list_type,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            topk_scales=topk_scales,
+            with_quant=self.with_quant,
+            fusion=self.quant_type == QuantType.W8A8,
+            need_trans=need_trans,
+            dynamic_eplb=dynamic_eplb)
 
         final_hidden_states = self.token_dispatcher.token_combine(
             hidden_states=mlp_output, context_metadata=context_metadata)
@@ -204,7 +226,8 @@ class AllGatherCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAllGather(self.moe_config)
+        return PrepareAndFinalizeWithAllGather(self.moe_config,
+                                               self.quant_type)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -221,7 +244,7 @@ class MC2CommImpl(MoECommMethod):
         return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithMC2(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -241,7 +264,7 @@ class AlltoAllCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -270,4 +293,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
+                                                    self.quant_type)
```
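
The patch imports `QuantType` from `prepare_finalize`, but its definition is not part of this diff. Given how it is used above (`QuantType.NONE`/`W8A8`/`W4A8`, compared with `!=` and `==`), it is presumably a small enum along these lines (hypothetical sketch, not the actual source):

```python
from enum import Enum

class QuantType(Enum):
    # Resolved once in MoECommMethod.__init__ via _get_quant_type() and
    # handed to each PrepareAndFinalize* implementation so quantization
    # can happen in prepare(), i.e. before the collective.
    # Member values here are guesses for illustration.
    NONE = 0
    W8A8 = 1  # int8 weights + int8 activations; takes the fused MLP path
    W4A8 = 2  # int4 weights + int8 activations
```

Note the accompanying signature change: callers of `setup_moe_comm_method` now also pass the layer's `quant_method`, so each comm implementation knows at construction time whether, and how, tokens should be quantized before dispatch.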