[Refactor] [SP] The sequence parallelism features of the MoE and Dense models are integrated into a single solution. (#3085)

What this PR does / why we need it?

There are two sets of SP (sequence parallelism) implementations for MoE and
dense models: one is called sequence_parallelism, and the other is flashcomm_v1.
We did the following things:

- Merge the two sets of code that share the same implementation into one.
- Remove the implementation of sequence_parallelism, as this solution
  cannot support aclgraph.
Does this PR introduce any user-facing change?

No

How was this patch tested?

End-to-end (e2e) tests and unit tests (UT).

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
weijinqian0
2025-09-24 11:29:59 +08:00
committed by GitHub
parent e7618d9414
commit 6aa4253798
14 changed files with 90 additions and 215 deletions

View File

@@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context) set_forward_context)
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
class FusedMoEState(Enum): class FusedMoEState(Enum):
@@ -101,21 +102,19 @@ def set_ascend_forward_context(
# due to multiple warmups before actual capturing # due to multiple warmups before actual capturing
forward_context.capturing = False forward_context.capturing = False
# set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature. # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods. # the performance may degrade due to the switching of communication methods.
flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \ sp_enabled = enable_sp() and \
envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
tp_world_size > 1 and \ tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000 num_tokens is not None and num_tokens > 1000
if flashcomm_v1_enabled: if sp_enabled:
pad_size = (tp_world_size - pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size (num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
# set this for rope forward_oot using # set this for rope forward_oot using
forward_context.is_first_layer = True forward_context.is_first_layer = True

View File

@@ -163,7 +163,6 @@ class AscendMetadata:
# *************************** Other Properties *************************** # # *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False enable_dbo_across_dp: bool = False
is_only_prefill: bool = False
class AscendAttentionMetadataBuilder: class AscendAttentionMetadataBuilder:
@@ -236,8 +235,7 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
attn_mask=attn_mask, attn_mask=attn_mask,
attn_state=attn_state, attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
is_only_prefill=common_attn_metadata.is_only_prefill)
return attn_metadata return attn_metadata
def build_for_graph_capture( def build_for_graph_capture(

View File

@@ -17,14 +17,14 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py # Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
from typing import Optional, Union from typing import Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group) get_tp_group)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index, PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -100,7 +97,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
self, self,
hidden_states, hidden_states,
attn_metadata=None, attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
): ):
if attn_metadata is None: if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
@@ -119,7 +115,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
top_k=self.top_k, top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
shared_experts=None, shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
) )
return hidden_states return hidden_states
@@ -188,60 +183,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if not self.use_aclgraph:
hidden_states = self.mlp(
hidden_states, _metadata_for_padding=_metadata_for_padding)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile @support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel): class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -277,45 +218,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -340,7 +242,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = [] self.expert_weights: list[torch.Tensor] = []
@@ -361,16 +262,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.num_moe_layers = len(self.moe_layers) self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1 self.num_expert_groups = 1
self.num_shared_experts = 0 self.num_shared_experts = 0
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states

View File

@@ -216,7 +216,9 @@ class AscendFusedMoE(FusedMoE):
forward_context = get_forward_context() forward_context = get_forward_context()
hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits) hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(

View File

@@ -21,8 +21,7 @@ from typing import Any, Callable, Optional
import torch import torch
import torch_npu import torch_npu
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group) get_tp_group)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -42,7 +41,6 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state, get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p, get_rm_router_logits_state, is_310p,
@@ -360,8 +358,7 @@ class AscendFusedMoE(FusedMoE):
top_k: Optional[int] = None, top_k: Optional[int] = None,
shared_experts: Optional[Any] = None, shared_experts: Optional[Any] = None,
gate=None, gate=None,
replace_allreduce: bool = False, replace_allreduce: bool = False):
_metadata_for_padding: Optional[MetadataForPadding] = None):
assert self.quant_method is not None assert self.quant_method is not None
@@ -379,13 +376,7 @@ class AscendFusedMoE(FusedMoE):
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states) shared_hidden_states = shared_experts(hidden_states)
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill if forward_context.sp_enabled:
tp_size = get_tensor_model_parallel_world_size()
if enable_sp:
tp_rank = get_tensor_model_parallel_rank()
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
replace_allreduce = True replace_allreduce = True
hidden_states, router_logits = forward_context.moe_comm_method.prepare( hidden_states, router_logits = forward_context.moe_comm_method.prepare(

View File

@@ -48,8 +48,9 @@ from vllm.distributed.parallel_state import get_tp_group
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group) get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable, from vllm_ascend.utils import (dense_optim_enable, enable_sp,
mlp_tp_enable, oproj_tp_enable) matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
class CustomTensorParallelOp: class CustomTensorParallelOp:
@@ -82,10 +83,17 @@ class CustomTensorParallelOp:
self.skip_bias_add = self.layer.skip_bias_add self.skip_bias_add = self.layer.skip_bias_add
self.return_bias = self.layer.return_bias self.return_bias = self.layer.return_bias
self.quant_method = self.layer.quant_method self.quant_method = self.layer.quant_method
self.prefix = self.layer.prefix
def apply_impl(self, input_):
raise NotImplementedError
# Replace layer.forward to customize the layer computation process. # Replace layer.forward to customize the layer computation process.
def apply(self, input_): def apply(self, input_):
raise NotImplementedError output, output_bias = self.apply_impl(input_)
if not self.return_bias:
return output
return output, output_bias
class CustomColumnParallelOp(CustomTensorParallelOp): class CustomColumnParallelOp(CustomTensorParallelOp):
@@ -113,6 +121,14 @@ class CustomRowParallelOp(CustomTensorParallelOp):
self.reduce_results = self.layer.reduce_results self.reduce_results = self.layer.reduce_results
self.input_size_per_partition = self.layer.input_size_per_partition self.input_size_per_partition = self.layer.input_size_per_partition
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
if dense_optim_enable():
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
if not self.return_bias:
return output
return output, output_bias
class MLPColumnParallelOp(CustomColumnParallelOp): class MLPColumnParallelOp(CustomColumnParallelOp):
@@ -123,7 +139,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
def comm_group(self): def comm_group(self):
return get_mlp_tp_group() return get_mlp_tp_group()
def apply( def apply_impl(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -134,14 +150,12 @@ class MLPColumnParallelOp(CustomColumnParallelOp):
output = self.quant_method.apply(self.layer, input_parallel, bias) output = self.quant_method.apply(self.layer, input_parallel, bias)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp): class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
def apply( def apply_impl(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@@ -164,18 +178,16 @@ class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp):
else: else:
output = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
class DenseOptimQKVParallelOp(CustomColumnParallelOp): class SequenceQKVParallelOp(CustomColumnParallelOp):
def __init__(self, layer, prefix): def __init__(self, layer, prefix):
super().__init__(layer) super().__init__(layer)
self.prefix = prefix self.prefix = prefix
def apply( def apply_impl(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@@ -201,8 +213,6 @@ class DenseOptimQKVParallelOp(CustomColumnParallelOp):
else: else:
output = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
@@ -215,7 +225,7 @@ class MLPRowParallelOp(CustomRowParallelOp):
def comm_group(self): def comm_group(self):
return get_mlp_tp_group() return get_mlp_tp_group()
def apply( def apply_impl(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
@@ -234,8 +244,6 @@ class MLPRowParallelOp(CustomRowParallelOp):
output = self.comm_group.reduce_scatter(output_parallel, 0) output = self.comm_group.reduce_scatter(output_parallel, 0)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
@@ -248,7 +256,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
def comm_group(self): def comm_group(self):
return get_otp_group() return get_otp_group()
def apply( def apply_impl(
self, self,
input_: torch.Tensor, input_: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -294,8 +302,6 @@ class OProjRowParallelOp(CustomRowParallelOp):
# Handle bias return based on configuration # Handle bias return based on configuration
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
def update_attrs(self): def update_attrs(self):
@@ -311,7 +317,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
super().__init__(layer) super().__init__(layer)
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group) self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
def apply( def apply_impl(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel: if self.input_is_parallel:
@@ -335,8 +341,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
bias=bias_) bias=bias_)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
@classmethod @classmethod
@@ -359,13 +363,13 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
self.weight_t = self.layer.weight.t() self.weight_t = self.layer.weight.t()
class DenseOptimRowParallelOp(CustomRowParallelOp): class SequenceRowParallelOp(CustomRowParallelOp):
def __init__(self, layer, prefix): def __init__(self, layer, prefix):
super().__init__(layer) super().__init__(layer)
self.prefix = prefix self.prefix = prefix
def apply( def apply_impl(
self, input_: torch.Tensor self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@@ -391,12 +395,8 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
input_parallel, input_parallel,
bias=bias_) bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias return output, output_bias
def update_attrs(self): def update_attrs(self):
@@ -407,23 +407,22 @@ class DenseOptimRowParallelOp(CustomRowParallelOp):
def get_column_parallel_op( def get_column_parallel_op(
disable_tp, prefix, layer disable_tp, prefix, layer
) -> Tuple[ ) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp, SequenceQKVParallelOp]], int, int]:
DenseOptimQKVParallelOp]], int, int]:
if disable_tp: if disable_tp:
return None, 0, 1 return None, 0, 1
custom_op: Optional[Union[ custom_op: Optional[Union[
MLPColumnParallelOp, MLPColumnParallelOp,
DenseOptimMergedColumnParallelOp, SequenceMergedColumnParallelOp,
DenseOptimQKVParallelOp, SequenceQKVParallelOp,
]] = None ]] = None
if "gate_up_proj" in prefix and mlp_tp_enable(): if "gate_up_proj" in prefix and mlp_tp_enable():
custom_op = MLPColumnParallelOp(layer) custom_op = MLPColumnParallelOp(layer)
elif "gate_up_proj" in prefix and dense_optim_enable(): elif "gate_up_proj" in prefix and enable_sp():
custom_op = DenseOptimMergedColumnParallelOp(layer) custom_op = SequenceMergedColumnParallelOp(layer)
elif dense_optim_enable(): elif enable_sp():
custom_op = DenseOptimQKVParallelOp(layer, prefix) custom_op = SequenceQKVParallelOp(layer, prefix)
if custom_op is not None: if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size return custom_op, custom_op.tp_rank, custom_op.tp_size
@@ -435,21 +434,21 @@ def get_row_parallel_op(
disable_tp, prefix, layer disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp, ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp, MatmulAllreduceRowParallelOp,
DenseOptimRowParallelOp]], int, int]: SequenceRowParallelOp]], int, int]:
if disable_tp: if disable_tp:
return None, 0, 1 return None, 0, 1
custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp, custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp, MatmulAllreduceRowParallelOp,
DenseOptimRowParallelOp]] = None SequenceRowParallelOp]] = None
if "down_proj" in prefix and mlp_tp_enable(): if "down_proj" in prefix and mlp_tp_enable():
custom_op = MLPRowParallelOp(layer) custom_op = MLPRowParallelOp(layer)
elif "o_proj" in prefix and oproj_tp_enable(): elif "o_proj" in prefix and oproj_tp_enable():
custom_op = OProjRowParallelOp(layer) custom_op = OProjRowParallelOp(layer)
elif matmul_allreduce_enable(): elif matmul_allreduce_enable():
custom_op = MatmulAllreduceRowParallelOp(layer) custom_op = MatmulAllreduceRowParallelOp(layer)
elif dense_optim_enable(): elif enable_sp():
custom_op = DenseOptimRowParallelOp(layer, prefix) custom_op = SequenceRowParallelOp(layer, prefix)
if custom_op is not None: if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size return custom_op, custom_op.tp_rank, custom_op.tp_size

View File

@@ -133,11 +133,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
""" """
self.replace_allreduce = replace_allreduce self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp self.enable_shared_expert_dp = enable_shared_expert_dp
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
if self.tp_size > 1:
# Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]
if not self.replace_allreduce: if not self.replace_allreduce:
self.num_tokens, _ = hidden_states.shape self.num_tokens, _ = hidden_states.shape
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
target_pad_length = forward_context.padded_num_tokens target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens pad_size = target_pad_length - self.num_tokens
@@ -149,23 +153,16 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
(0, 0, 0, pad_size)) (0, 0, 0, pad_size))
# Slice across TP ranks # Slice across TP ranks
if self.tp_size > 1: if self.tp_size > 1 and not self.enable_shared_expert_dp:
if not self.enable_shared_expert_dp: split_hidden_states = torch.tensor_split(hidden_states,
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size,
self.tp_size, dim=0)
dim=0) split_router_logits = torch.tensor_split(router_logits,
split_router_logits = torch.tensor_split(router_logits, self.tp_size,
self.tp_size, dim=0)
dim=0) hidden_states = split_hidden_states[self.tp_rank]
hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] self.split_hidden_states = split_hidden_states # Save for finalize
self.split_hidden_states = split_hidden_states # Save for finalize
# Also slice mc2_mask
split_mc2_mask = torch.tensor_split(mc2_mask,
self.tp_size,
dim=0)
mc2_mask = split_mc2_mask[self.tp_rank]
return hidden_states, router_logits, mc2_mask return hidden_states, router_logits, mc2_mask

View File

@@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
return residual return residual
if x.size(0) != residual.size(0): if x.size(0) != residual.size(0):
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled sp_enabled = forward_context.sp_enabled
assert flashcomm_v1_enabled is True, ( assert sp_enabled is True, ("Currently, this situation only occurs "
"Currently, this situation only occurs " "when sp is enabled")
"when flashcomm_v1 is enabled")
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size)) residual = F.pad(residual, (0, 0, 0, pad_size))
@@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
except AssertionError: except AssertionError:
return x return x
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled sp_enabled = forward_context.sp_enabled
if flashcomm_v1_enabled and label: if sp_enabled and label:
x = tensor_model_parallel_all_gather(x, 0) x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0:
@@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
except AssertionError: except AssertionError:
return tensor_model_parallel_all_reduce(x) return tensor_model_parallel_all_reduce(x)
flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled sp_enabled = forward_context.sp_enabled
if flashcomm_v1_enabled: if sp_enabled:
pad_size = forward_context.pad_size pad_size = forward_context.pad_size
if pad_size > 0: if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size)) x = F.pad(x, (0, 0, 0, pad_size))

View File

@@ -282,12 +282,6 @@ class NPUPlatform(Platform):
ascend_config.ascend_scheduler_config) ascend_config.ascend_scheduler_config)
vllm_config.scheduler_config = ascend_scheduler_config vllm_config.scheduler_config = ascend_scheduler_config
if compilation_config.pass_config.enable_sequence_parallelism:
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
raise NotImplementedError(
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
)
@classmethod @classmethod
def get_attn_backend_cls(cls, def get_attn_backend_cls(cls,
selected_backend, selected_backend,

View File

@@ -54,8 +54,8 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp) init_metadata_for_sp)
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):

View File

@@ -44,8 +44,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map) determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state, get_all_reduce_merge_state,

View File

@@ -590,6 +590,14 @@ def dense_optim_enable() -> bool:
return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE
def enable_sp() -> bool:
from vllm.config import get_cached_compilation_config
return (
get_cached_compilation_config().pass_config.enable_sequence_parallelism
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
def is_moe_model(vllm_config: VllmConfig): def is_moe_model(vllm_config: VllmConfig):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
return any('experts' in key.lower() for key in config.to_dict()) return any('experts' in key.lower() for key in config.to_dict())

View File

@@ -1582,7 +1582,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
update_attn_params(self.update_stream, forward_context, update_attn_params(self.update_stream, forward_context,
positions.shape[0]) positions.shape[0])
if get_forward_context().flashcomm_v1_enabled: if get_forward_context().sp_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
pad_size = get_forward_context().pad_size pad_size = get_forward_context().pad_size
if pad_size > 0: if pad_size > 0: