### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -16,22 +16,23 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
|
||||
prefill_context_parallel_enable)
|
||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
@@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC):
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
quant_stream: Optional[torch.npu.Stream] = None
|
||||
|
||||
quant_stream: torch.npu.Stream | None = None
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
@@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type: QuantType = QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
@@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
@@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
@@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
context_metadata = {
|
||||
"padded_hidden_states_shape": padded_hidden_states_shape
|
||||
}
|
||||
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
|
||||
|
||||
return hidden_states, router_logits, None, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
@@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
# may share memory with original hidden_states. Since shared
|
||||
# experts may use the original tensor, reusing it would cause
|
||||
# in-place modification during all_gather, corrupting the data.
|
||||
padded_hidden_states_shape = context_metadata[
|
||||
"padded_hidden_states_shape"]
|
||||
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
|
||||
gathered_hidden_states = torch.empty(
|
||||
padded_hidden_states_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
split_hidden_states = torch.tensor_split(
|
||||
gathered_hidden_states, self.tp_size, dim=0)
|
||||
dist.all_gather(list(split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0)
|
||||
dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group)
|
||||
hidden_states = gathered_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
hidden_states = hidden_states[: self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
@@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
if pad_size > 0 and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
@@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
@@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits,
|
||||
quant_type)
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce)
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
pertoken_scale = None
|
||||
if quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
assert PrepareAndFinalize.quant_stream is not None
|
||||
PrepareAndFinalize.quant_stream.wait_stream(
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(PrepareAndFinalize.quant_stream,
|
||||
enabled=self.multistream_overlap_gate):
|
||||
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
|
||||
hidden_states)
|
||||
PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream())
|
||||
with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate):
|
||||
hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
pertoken_scale, True, True)
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
PrepareAndFinalize.quant_stream)
|
||||
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
||||
@@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
@@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
@@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
Reduce Scatter hidden states.
|
||||
@@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
|
||||
return self._finalize_with_dp_group(hidden_states, reduce_results)
|
||||
|
||||
def _finalize_with_ep_group(self,
|
||||
hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
|
||||
1. Reduce_results is False usually happens when models have shared experts and need to
|
||||
@@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
|
||||
here, then skip allreudce in FusedMoe.
|
||||
"""
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states, True)
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
@@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
hidden_states = hidden_states[: self.num_tokens]
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
|
||||
dim=0)
|
||||
hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0)
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user