[Refactor] [SP]The sequence parallelism characteristics in the MoE and Dense models are integrated into a single solution. (#3085)

What this PR does / why we need it?

there are two sets of sp implementations for moe and dense models. One
is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:

Merge two sets of code with the same implementation into one.
Remove the implementation of sequence_parallelism, as this solution
cannot support aclgraph.
Does this PR introduce any user-facing change?

No

How was this patch tested?

e2e&ut

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-09-24 11:29:59 +08:00
committed by GitHub
parent e7618d9414
commit 6aa4253798
14 changed files with 90 additions and 215 deletions

View File

@@ -17,14 +17,14 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional, Union
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -100,7 +97,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
@@ -119,7 +115,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)
return hidden_states
@@ -188,60 +183,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if not self.use_aclgraph:
hidden_states = self.mlp(
hidden_states, _metadata_for_padding=_metadata_for_padding)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -277,45 +218,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -340,7 +242,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
@@ -361,16 +262,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states