[Perf]enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
moe multistream overlap to improve the performance.
### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
@@ -13,6 +13,10 @@ class TestPrepareAndFinalize(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Mock FusedMoEConfig
|
# Mock FusedMoEConfig
|
||||||
|
fake_stream = MagicMock()
|
||||||
|
patcher = patch("torch.npu.Stream", return_value=fake_stream)
|
||||||
|
patcher.start()
|
||||||
|
self.addCleanup(patcher.stop)
|
||||||
self.moe_config = MagicMock(spec=FusedMoEConfig)
|
self.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||||
self.moe_config.tp_group = MagicMock()
|
self.moe_config.tp_group = MagicMock()
|
||||||
self.moe_config.tp_group.device_group = MagicMock()
|
self.moe_config.tp_group.device_group = MagicMock()
|
||||||
|
|||||||
@@ -106,6 +106,8 @@ class AscendConfig:
|
|||||||
enable_shared_expert_dp=True)
|
enable_shared_expert_dp=True)
|
||||||
self.multistream_overlap_shared_expert = additional_config.get(
|
self.multistream_overlap_shared_expert = additional_config.get(
|
||||||
"multistream_overlap_shared_expert", False)
|
"multistream_overlap_shared_expert", False)
|
||||||
|
self.multistream_overlap_gate = additional_config.get(
|
||||||
|
"multistream_overlap_gate", False)
|
||||||
self.recompute_scheduler_enable = additional_config.get(
|
self.recompute_scheduler_enable = additional_config.get(
|
||||||
"recompute_scheduler_enable", False)
|
"recompute_scheduler_enable", False)
|
||||||
self.enable_cpu_binding = additional_config.get(
|
self.enable_cpu_binding = additional_config.get(
|
||||||
|
|||||||
@@ -20,9 +20,10 @@ _OTP: Optional[GroupCoordinator] = None
|
|||||||
_LMTP: Optional[GroupCoordinator] = None
|
_LMTP: Optional[GroupCoordinator] = None
|
||||||
_EMBED_TP: Optional[GroupCoordinator] = None
|
_EMBED_TP: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
# flashcomm2 specific groups
|
# flashcomm specific groups
|
||||||
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
|
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
|
||||||
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
|
||||||
|
_FC3_QUANT_X: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
# shared_weight across rank groups
|
# shared_weight across rank groups
|
||||||
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
|
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
|
||||||
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
|
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
|
||||||
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
|
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
|
||||||
|
|
||||||
|
if get_ascend_config().multistream_overlap_gate:
|
||||||
|
global _FC3_QUANT_X
|
||||||
|
group_ranks = all_ranks.unbind(0)
|
||||||
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
|
_FC3_QUANT_X = init_model_parallel_group(group_ranks,
|
||||||
|
get_world_group().local_rank,
|
||||||
|
backend,
|
||||||
|
group_name="fc3_quant_x")
|
||||||
|
|
||||||
|
|
||||||
def model_parallel_initialized():
|
def model_parallel_initialized():
|
||||||
return (_MC2 is not None)
|
return (_MC2 is not None)
|
||||||
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
|
|||||||
return _P_TP
|
return _P_TP
|
||||||
|
|
||||||
|
|
||||||
|
def get_fc3_quant_x_group() -> GroupCoordinator:
|
||||||
|
assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
|
||||||
|
return _FC3_QUANT_X
|
||||||
|
|
||||||
|
|
||||||
def destroy_ascend_model_parallel():
|
def destroy_ascend_model_parallel():
|
||||||
global _MC2
|
global _MC2
|
||||||
if _MC2:
|
if _MC2:
|
||||||
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
|
|||||||
if _SHARED_WEIGHT:
|
if _SHARED_WEIGHT:
|
||||||
_SHARED_WEIGHT.destroy()
|
_SHARED_WEIGHT.destroy()
|
||||||
_SHARED_WEIGHT = None
|
_SHARED_WEIGHT = None
|
||||||
|
|
||||||
|
global _FC3_QUANT_X
|
||||||
|
if _FC3_QUANT_X:
|
||||||
|
_FC3_QUANT_X.destroy()
|
||||||
|
_FC3_QUANT_X = None
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import get_p_tp_group
|
from vllm_ascend.distributed.parallel_state import (get_dp_group,
|
||||||
|
get_fc3_quant_x_group,
|
||||||
|
get_p_tp_group)
|
||||||
|
|
||||||
|
|
||||||
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
||||||
@@ -59,3 +62,31 @@ def get_transfer_timeout_value():
|
|||||||
'7')) # type: ignore
|
'7')) # type: ignore
|
||||||
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
|
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
|
||||||
3000)
|
3000)
|
||||||
|
|
||||||
|
|
||||||
|
def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
except AssertionError:
|
||||||
|
return x
|
||||||
|
x = get_fc3_quant_x_group().all_gather(x, 0)
|
||||||
|
dp_metadata = forward_context.dp_metadata
|
||||||
|
if dp_metadata is None:
|
||||||
|
pad_size = forward_context.pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
x = x[:-pad_size]
|
||||||
|
else:
|
||||||
|
# unpad
|
||||||
|
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
|
||||||
|
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype)
|
||||||
|
dp_size = get_dp_group().world_size
|
||||||
|
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
|
||||||
|
offset = 0
|
||||||
|
for idx in range(dp_size):
|
||||||
|
num_tokens_dp = num_tokens_across_dp_cpu[idx]
|
||||||
|
result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
|
||||||
|
offset += num_tokens_dp
|
||||||
|
x = result
|
||||||
|
return x
|
||||||
|
|||||||
42
vllm_ascend/flash_common3_context.py
Normal file
42
vllm_ascend/flash_common3_context.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlashCommon3Context:
|
||||||
|
gate: Optional[LinearBase] = None
|
||||||
|
topk_weights: Optional[torch.Tensor] = None
|
||||||
|
topk_ids: Optional[torch.Tensor] = None
|
||||||
|
row_idx: Optional[torch.Tensor] = None
|
||||||
|
shared_experts: Optional[torch.nn.Module] = None
|
||||||
|
shared_out: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
_flash_common3_context: Optional[FlashCommon3Context] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_flash_common3_context() -> Optional[FlashCommon3Context]:
|
||||||
|
return _flash_common3_context
|
||||||
|
|
||||||
|
|
||||||
|
def set_flash_common3_context(
|
||||||
|
topk_weights: Optional[torch.Tensor] = None,
|
||||||
|
topk_ids: Optional[torch.Tensor] = None,
|
||||||
|
shared_experts: Optional[torch.nn.Module] = None,
|
||||||
|
shared_out: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
global _flash_common3_context
|
||||||
|
if _flash_common3_context is None:
|
||||||
|
_flash_common3_context = FlashCommon3Context()
|
||||||
|
|
||||||
|
if topk_weights is not None:
|
||||||
|
_flash_common3_context.topk_weights = topk_weights
|
||||||
|
if topk_ids is not None:
|
||||||
|
_flash_common3_context.topk_ids = topk_ids
|
||||||
|
if shared_experts is not None:
|
||||||
|
_flash_common3_context.shared_experts = shared_experts
|
||||||
|
if shared_out is not None:
|
||||||
|
_flash_common3_context.shared_out = shared_out
|
||||||
@@ -37,9 +37,12 @@ from vllm_ascend.ascend_forward_context import MoECommType
|
|||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
|
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
|
||||||
from vllm_ascend.eplb.utils import moe_load_async_stream
|
from vllm_ascend.eplb.utils import moe_load_async_stream
|
||||||
|
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
|
||||||
|
set_flash_common3_context)
|
||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
|
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||||
|
setup_moe_comm_method)
|
||||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
||||||
from vllm_ascend.quantization.w4a8_dynamic import \
|
from vllm_ascend.quantization.w4a8_dynamic import \
|
||||||
AscendW4A8DynamicFusedMoEMethod
|
AscendW4A8DynamicFusedMoEMethod
|
||||||
@@ -139,6 +142,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
class AscendFusedMoE(FusedMoE):
|
||||||
moe_counter = -1
|
moe_counter = -1
|
||||||
|
gate_stream: Optional[torch.npu.Stream] = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -170,6 +174,10 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.expert_map_path = ascend_config.expert_map_path
|
self.expert_map_path = ascend_config.expert_map_path
|
||||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||||
|
# flashcommon3 gate stream
|
||||||
|
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||||
|
if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
|
||||||
|
AscendFusedMoE.gate_stream = torch.npu.Stream()
|
||||||
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
|
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
||||||
@@ -332,6 +340,47 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
enable_force_load_balance = forward_context.in_profile_run
|
enable_force_load_balance = forward_context.in_profile_run
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
|
if self.multistream_overlap_gate:
|
||||||
|
assert AscendFusedMoE.gate_stream is not None
|
||||||
|
fc3_context = get_flash_common3_context()
|
||||||
|
assert fc3_context is not None
|
||||||
|
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
|
||||||
|
with npu_stream_switch(AscendFusedMoE.gate_stream,
|
||||||
|
enabled=self.multistream_overlap_gate):
|
||||||
|
# share_expert
|
||||||
|
assert fc3_context.shared_experts is not None
|
||||||
|
shared_out = fc3_context.shared_experts(hidden_states)
|
||||||
|
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||||
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||||
|
and not shared_expert_dp_enabled():
|
||||||
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
|
set_flash_common3_context(shared_out=shared_out)
|
||||||
|
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=self.top_k,
|
||||||
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
custom_routing_function=self.custom_routing_function,
|
||||||
|
scoring_func=self.scoring_func,
|
||||||
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
|
global_num_experts=self.global_num_experts)
|
||||||
|
|
||||||
|
if isinstance(forward_context.moe_comm_method,
|
||||||
|
AllGatherCommImpl):
|
||||||
|
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
topk_weights, True, True)
|
||||||
|
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
topk_ids, True, True)
|
||||||
|
|
||||||
|
set_flash_common3_context(topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids)
|
||||||
|
|
||||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -339,6 +388,10 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||||
quant_type=self.quant_type)
|
quant_type=self.quant_type)
|
||||||
|
|
||||||
|
# Make sure the default stream waits for the gate stream to finish.
|
||||||
|
if self.multistream_overlap_gate:
|
||||||
|
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
|
||||||
|
|
||||||
if isinstance(hidden_states, tuple):
|
if isinstance(hidden_states, tuple):
|
||||||
hidden_states, pertoken_scale = hidden_states
|
hidden_states, pertoken_scale = hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -407,6 +460,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
self.shared_expert_stream = None
|
self.shared_expert_stream = None
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
|
||||||
|
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||||
if enable_sp():
|
if enable_sp():
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
||||||
@@ -443,30 +497,42 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
|||||||
|
|
||||||
def forward_impl(self, hidden_states: torch.Tensor,
|
def forward_impl(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
|
shared_out = None
|
||||||
|
if not self.multistream_overlap_gate:
|
||||||
# Make sure the shared experts stream begins after hidden_states are ready.
|
# Make sure the shared experts stream begins after hidden_states are ready.
|
||||||
if self.multistream_overlap_shared_expert:
|
if self.multistream_overlap_shared_expert:
|
||||||
shared_experts_calculation_stream().wait_stream( # type: ignore
|
shared_experts_calculation_stream(
|
||||||
|
).wait_stream( # type: ignore
|
||||||
torch.npu.current_stream())
|
torch.npu.current_stream())
|
||||||
with npu_stream_switch(shared_experts_calculation_stream(),
|
with npu_stream_switch(
|
||||||
|
shared_experts_calculation_stream(),
|
||||||
enabled=self.multistream_overlap_shared_expert):
|
enabled=self.multistream_overlap_shared_expert):
|
||||||
# Use a separate stream to run shared experts.
|
# Use a separate stream to run shared experts.
|
||||||
# Note that currently we only support calculations in separate streams with aclgraph.
|
|
||||||
# Communication operations in another stream might cause unknown errors.
|
|
||||||
shared_out = self._shared_experts(hidden_states)
|
shared_out = self._shared_experts(hidden_states)
|
||||||
|
else:
|
||||||
|
set_flash_common3_context(shared_experts=self._shared_experts)
|
||||||
|
|
||||||
fused_output = AscendFusedMoE.forward_impl(
|
fused_output = AscendFusedMoE.forward_impl(
|
||||||
self,
|
self,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.multistream_overlap_gate:
|
||||||
# Make sure the default stream waits for the shared experts stream to finish.
|
# Make sure the default stream waits for the shared experts stream to finish.
|
||||||
if self.multistream_overlap_shared_expert:
|
if self.multistream_overlap_shared_expert:
|
||||||
torch.npu.current_stream().wait_stream(
|
torch.npu.current_stream().wait_stream(
|
||||||
shared_experts_calculation_stream())
|
shared_experts_calculation_stream())
|
||||||
|
|
||||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||||
and not shared_expert_dp_enabled():
|
and not shared_expert_dp_enabled():
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
|
else:
|
||||||
|
fc3_context = get_flash_common3_context()
|
||||||
|
assert fc3_context is not None
|
||||||
|
shared_out = fc3_context.shared_out
|
||||||
|
|
||||||
return shared_out, fused_output
|
return shared_out, fused_output
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||||
|
|
||||||
from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||||
|
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
|
||||||
|
prefill_context_parallel_enable)
|
||||||
|
|
||||||
|
|
||||||
class QuantType(Enum):
|
class QuantType(Enum):
|
||||||
@@ -49,9 +52,14 @@ class PrepareAndFinalize(ABC):
|
|||||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||||
sizes, ranks, and communication settings.
|
sizes, ranks, and communication settings.
|
||||||
"""
|
"""
|
||||||
|
quant_stream: Optional[torch.npu.Stream] = None
|
||||||
|
|
||||||
def __init__(self, moe_config: FusedMoEConfig):
|
def __init__(self, moe_config: FusedMoEConfig):
|
||||||
self.moe_config = moe_config
|
self.moe_config = moe_config
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||||
|
if self.multistream_overlap_gate and PrepareAndFinalize.quant_stream is None:
|
||||||
|
PrepareAndFinalize.quant_stream = torch.npu.Stream()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare(
|
def prepare(
|
||||||
@@ -335,13 +343,29 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
|||||||
if quant_type == QuantType.W8A8:
|
if quant_type == QuantType.W8A8:
|
||||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||||
hidden_states)
|
hidden_states)
|
||||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
pertoken_scale, True, True)
|
if self.multistream_overlap_gate:
|
||||||
|
assert PrepareAndFinalize.quant_stream is not None
|
||||||
|
PrepareAndFinalize.quant_stream.wait_stream(
|
||||||
|
torch.npu.current_stream())
|
||||||
|
with npu_stream_switch(PrepareAndFinalize.quant_stream,
|
||||||
|
enabled=self.multistream_overlap_gate):
|
||||||
|
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
|
||||||
|
hidden_states)
|
||||||
|
else:
|
||||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
hidden_states, True, True)
|
hidden_states, True, True)
|
||||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
router_logits, True, True)
|
router_logits, True, True)
|
||||||
|
|
||||||
|
if pertoken_scale is not None:
|
||||||
|
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
pertoken_scale, True, True)
|
||||||
|
|
||||||
|
if self.multistream_overlap_gate:
|
||||||
|
torch.npu.current_stream().wait_stream(
|
||||||
|
PrepareAndFinalize.quant_stream)
|
||||||
|
|
||||||
if pertoken_scale is not None:
|
if pertoken_scale is not None:
|
||||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
return (hidden_states, pertoken_scale), router_logits, None, None
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
from vllm_ascend.ascend_forward_context import MoECommType
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
|
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||||
|
|
||||||
@@ -114,6 +115,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
self.use_aclgraph = (vllm_config.compilation_config.mode
|
self.use_aclgraph = (vllm_config.compilation_config.mode
|
||||||
== CompilationMode.VLLM_COMPILE
|
== CompilationMode.VLLM_COMPILE
|
||||||
and not vllm_config.model_config.enforce_eager)
|
and not vllm_config.model_config.enforce_eager)
|
||||||
|
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
|
||||||
|
|
||||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||||
self.in_dtype = vllm_config.model_config.dtype
|
self.in_dtype = vllm_config.model_config.dtype
|
||||||
@@ -198,6 +200,13 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
assert router_logits.shape[
|
assert router_logits.shape[
|
||||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||||
|
|
||||||
|
topk_weights, topk_ids = None, None
|
||||||
|
if self.multistream_overlap_gate:
|
||||||
|
fc3_context = get_flash_common3_context()
|
||||||
|
assert fc3_context is not None
|
||||||
|
topk_weights = fc3_context.topk_weights
|
||||||
|
topk_ids = fc3_context.topk_ids
|
||||||
|
else:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -222,6 +231,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
topk_ids = torch.argsort(
|
topk_ids = torch.argsort(
|
||||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||||
|
|
||||||
|
assert topk_weights is not None
|
||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
|
|||||||
Reference in New Issue
Block a user