[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #11) (#6176)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
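
The changes are mechanical: yapf-style wrapped calls are collapsed onto single lines, grouped imports are expanded to one name per line with a trailing comma, and `Optional[X]` annotations become the PEP 604 `X | None` form. A representative before/after pair, abridged from the `prepare_finalize.py` diff below:

```python
# Before (yapf-wrapped, typing.Optional)
hidden_states = nn.functional.pad(hidden_states,
                                  (0, 0, 0, pad_size))
quant_stream: Optional[torch.npu.Stream] = None

# After (ruff format, PEP 604 union)
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
quant_stream: torch.npu.Stream | None = None
```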

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Authored by SILONG ZENG on 2026-02-06 15:28:49 +08:00, committed via GitHub
parent 4fb3d5e1b2 · commit 65b7f716e6
8 changed files with 694 additions and 784 deletions


@@ -16,22 +16,23 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed.parallel_state import (
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
prefill_context_parallel_enable)
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
class QuantType(Enum):
@@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC):
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings.
"""
quant_stream: Optional[torch.npu.Stream] = None
quant_stream: torch.npu.Stream | None = None
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC):
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
@@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape
}
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
return hidden_states, router_logits, None, context_metadata
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
@@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[
"padded_hidden_states_shape"]
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
split_hidden_states = torch.tensor_split(
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
)
split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group)
hidden_states = gathered_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
hidden_states = hidden_states[: self.num_tokens]
return hidden_states
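
The behaviour these hunks reformat (spelled out in the prepare/finalize docstrings) is a pad → split → all-gather → trim round trip across TP ranks. A minimal single-process sketch of that pattern, with the collective emulated locally (`tp_size`, the rank index, and the copy loop are stand-ins for the real `dist.all_gather` on the TP group's `device_group`):

```python
import torch
import torch.nn as nn

tp_size, num_tokens, hidden = 4, 10, 8
hidden_states = torch.randn(num_tokens, hidden)

# prepare: pad the token dim so it splits evenly, then keep one rank's slice
pad_size = -num_tokens % tp_size  # "next multiple" padding for this toy; the real code uses its own formula
if pad_size > 0:
    hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
padded_shape = hidden_states.shape
splits = torch.tensor_split(hidden_states, tp_size, dim=0)
local_slice = splits[0]  # pretend we are TP rank 0

# finalize: rebuild the padded tensor from every rank's slice, then trim the padding
gathered = torch.empty(padded_shape)
for rank, view in enumerate(torch.tensor_split(gathered, tp_size, dim=0)):
    view.copy_(splits[rank])  # emulates dist.all_gather(list_of_views, local_slice, group)
hidden_states = gathered[:num_tokens]
assert hidden_states.shape == (num_tokens, hidden)
```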
@@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
# Pad if necessary (unless shared expert DP is enabled)
if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
# Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
@@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
@@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits,
quant_type)
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
replace_allreduce)
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
def _prepare_with_ep_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
pertoken_scale = None
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if self.multistream_overlap_gate:
assert PrepareAndFinalize.quant_stream is not None
PrepareAndFinalize.quant_stream.wait_stream(
torch.npu.current_stream())
with npu_stream_switch(PrepareAndFinalize.quant_stream,
enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
hidden_states)
PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states)
else:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True)
if pertoken_scale is not None:
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
pertoken_scale, True, True)
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True)
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(
PrepareAndFinalize.quant_stream)
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None
@@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
# All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
@@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return hidden_states, router_logits, None, None
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalization steps:
Reduce Scatter hidden states.
@@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self,
hidden_states: torch.Tensor) -> torch.Tensor:
def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. Reduce_results is False usually happens when models have shared experts and need to
@@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
here, then skip allreduce in FusedMoe.
"""
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states, True)
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True)
return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
@@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
hidden_states = hidden_states[: self.num_tokens]
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
dim=0)
hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0)
return hidden_states
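
The DP finalize path above is the mirror image: the expert output computed on the all-gathered tokens is reduce-scattered back across the DP group and then trimmed to the rank's own token count. A single-process sketch under the same toy assumptions (`dp_size`, `max_tokens`, the rank index, and the local sum stand in for `get_dp_group().reduce_scatter(hidden_states, 0)`):

```python
import torch

dp_size, num_tokens, max_tokens, hidden = 2, 3, 4, 8

# After prepare/all-gather, every rank holds the full (dp_size * max_tokens) token block.
per_rank_outputs = [torch.randn(dp_size * max_tokens, hidden) for _ in range(dp_size)]

# Emulated reduce_scatter over dim 0: sum the contributions, then keep this rank's chunk.
rank = 0
summed = torch.stack(per_rank_outputs).sum(dim=0)
hidden_states = torch.tensor_split(summed, dp_size, dim=0)[rank]

# Trim the padding added in prepare so the output matches the rank's real token count.
hidden_states = hidden_states[:num_tokens]
assert hidden_states.shape == (num_tokens, hidden)
```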