[Perf] move quant before allgather in Allgather EP (#3420)
### What this PR does / why we need it?
move quant before allgather in Allgather EP, rely on
https://github.com/vllm-project/vllm-ascend/pull/3334
Deepseek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|----------|----------|----------|
| 4k | 375.21 | 364.99 |
| 16k | 1465.23 | 1421.75 |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -15,11 +15,13 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_tensor_model_parallel_rank,
|
||||
@@ -30,6 +32,12 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
from vllm_ascend.utils import enable_sp
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
NONE = 0
|
||||
W8A8 = 1
|
||||
W4A8 = 2
|
||||
|
||||
|
||||
class PrepareAndFinalize(ABC):
|
||||
"""
|
||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||
@@ -42,8 +50,11 @@ class PrepareAndFinalize(ABC):
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
self.moe_config = moe_config
|
||||
self.quant_type = quant_type
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
@@ -103,8 +114,10 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
super().__init__(moe_config, quant_type)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
@@ -195,8 +208,10 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
|
||||
"""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
def __init__(self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_type: QuantType = QuantType.NONE):
|
||||
super().__init__(moe_config, quant_type)
|
||||
self._restore_tp_across_dp()
|
||||
|
||||
def _restore_tp_across_dp(self):
|
||||
@@ -316,11 +331,20 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
pertoken_scale = None
|
||||
if self.quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
pertoken_scale, True, True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def _prepare_with_dp_group(
|
||||
|
||||
Reference in New Issue
Block a user