【main】SP For Qwen3 MoE (#2209)

### What this PR does / why we need it?
This PR adds sequence parallelism (SP) support for Qwen3 MoE. In the AlltoAll, AlltoAllv, and MC2 communication scenarios, replacing AllReduce with ReduceScatter and AllGather lets the norm operations run on only 1/TP of the tokens while also saving one AllGather communication. The feature is enabled during the prefill (P) phase and delivers notable gains in long-sequence scenarios (e.g., 16k–25k), with performance improvements of 5%–10%.
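
For readers unfamiliar with the trick, here is a minimal single-process sketch (illustrative only, not the PR's code) of why the substitution is numerically equivalent: summing the per-rank partial outputs shard-by-shard and re-concatenating reproduces the AllReduce result, while the norm between the two collectives only touches one shard.

```python
import torch

# Sketch: ReduceScatter + AllGather vs. AllReduce (sizes are made up).
tp_size, num_tokens, hidden = 2, 8, 4
partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]

all_reduced = sum(partials)  # what AllReduce would leave on every rank

# ReduceScatter: rank r keeps the summed shard r along the token dimension.
shards = [
    sum(p.chunk(tp_size, dim=0)[r] for p in partials) for r in range(tp_size)
]
# ...each rank would run RMSNorm / MoE dispatch on its [4, 4] shard here...
gathered = torch.cat(shards, dim=0)  # AllGather restores the full sequence

assert torch.allclose(all_reduced, gathered)
```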
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```python
compilation_config={
    "pass_config":{
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```
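
For reference, a complete offline run with the same knobs might look like the sketch below; the model id, `tensor_parallel_size`, and prompt are illustrative choices, while `compilation_config` and `enable_expert_parallel` are standard vLLM engine arguments.

```python
from vllm import LLM, SamplingParams

# Sketch of an offline run exercising SP for Qwen3 MoE (assumed setup).
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    compilation_config={
        "pass_config": {
            "enable_sequence_parallelism": True
        }
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```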

- vLLM version: v0.10.0
- vLLM main: 9edd1db02b

---------

Signed-off-by: libaokui <libaokui@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
Authored by lbk-sys on 2025-08-07 09:15:49 +08:00, committed by GitHub
parent 57b9f02185
commit c611291661
11 changed files with 299 additions and 11 deletions

View File

@@ -284,6 +284,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
             --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \

View File

@@ -234,3 +234,27 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC():
         },
     ) as vllm_model:
         vllm_model.generate_greedy(prompts, max_tokens)
+
+
+def test_sp_for_qwen3_moe() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+    with VllmRunner(
+            snapshot_download("Qwen/Qwen3-30B-A3B"),
+            dtype="auto",
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+            compilation_config={
+                "pass_config": {
+                    "enable_sequence_parallelism": True
+                }
+            },
+            enable_expert_parallel=True,
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

View File

@@ -26,6 +26,7 @@ class TestNPUPlatform(TestBase):
         self.mock_vllm_config.cache_config = MagicMock()
         self.mock_vllm_config.scheduler_config = MagicMock()
         self.mock_vllm_config.speculative_config = None
+        self.mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
         self.mock_ascend_config = MagicMock()
         self.mock_ascend_config.torchair_graph_config.enabled = False

View File

@@ -151,6 +151,7 @@ class AscendMetadata:
     slot_mapping: torch.Tensor = None
     enable_dbo_across_dp: bool = False
+    is_only_prefill: bool = False


 class AscendAttentionMetadataBuilder:
@@ -166,7 +167,8 @@ class AscendAttentionMetadataBuilder:
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              enable_dbo_across_dp: bool = False):
+              enable_dbo_across_dp: bool = False,
+              is_only_prefill: bool = False):
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
@@ -203,7 +205,8 @@ class AscendAttentionMetadataBuilder:
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            is_only_prefill=is_only_prefill)
         return attn_metadata

View File

@@ -223,7 +223,9 @@ class AscendAttentionTorchairMetadataBuilder:
               num_actual_tokens,
               max_query_len,
               graph_pad_size: int = -1,
-              enable_dbo_across_dp: bool = False):
+              enable_dbo_across_dp: bool = False,
+              *args,
+              **kwargs):
         device = self.runner.device

View File

@@ -384,6 +384,8 @@ class AscendMLAMetadataBuilder:
         graph_pad_size: int = -1,
         query_start_loc: torch.Tensor = None,
         enable_dbo_across_dp: bool = False,
+        *args,
+        **kwargs,
     ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

View File

@@ -16,14 +16,15 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -44,8 +45,11 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
+from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
+                                               init_metadata_for_sp)


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -96,6 +100,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
         self,
         hidden_states,
         attn_metadata=None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -114,6 +119,7 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
+            _metadata_for_padding=_metadata_for_padding,
         )

         return hidden_states
@@ -155,14 +161,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
-        use_aclgraph = (vllm_config is not None
-                        and vllm_config.compilation_config.level
-                        == CompilationLevel.PIECEWISE
-                        and not vllm_config.model_config.enforce_eager)
+        self.use_aclgraph = (vllm_config is not None
+                             and vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            if not use_aclgraph:
+            if not self.use_aclgraph:
                 # FIXME: custom sparse moe block doesn't work with aclgraph.
                 self.mlp = CustomSparseMoeBlock(config=config,
                                                 quant_config=quant_config,
@@ -182,6 +188,60 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.enable_sequence_parallelism = (
+            vllm_config.compilation_config.pass_config.
+            enable_sequence_parallelism if vllm_config is not None else False)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> torch.Tensor:
+        # To prevent precision issues during decode when SP is enabled only
+        # for prefill, keep the o_proj AllReduce for non-prefill batches.
+        if not self.enable_sequence_parallelism:
+            self.self_attn.o_proj.reduce_results = True
+        else:
+            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
+
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+                residual = _metadata_for_padding.padding_slice(residual)
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
+                hidden_states)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+
+        if not self.use_aclgraph:
+            hidden_states = self.mlp(
+                hidden_states, _metadata_for_padding=_metadata_for_padding)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
 @support_torch_compile
 class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -216,6 +276,45 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        _metadata_for_padding: Optional[MetadataForPadding] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                _metadata_for_padding=_metadata_for_padding)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
+                hidden_states)
+
+        return hidden_states
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -253,6 +352,7 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
@@ -273,3 +373,16 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        _metadata_for_padding = init_metadata_for_sp(
+            input_ids, self.enable_sequence_parallelism)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, _metadata_for_padding)
+        return hidden_states
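
To make the decoder-layer dataflow above easier to follow, here is a single-process shape walk-through with made-up sizes; it simulates only the shapes the collectives produce, not the distributed ops themselves, and assumes the steady state where the residual is already sharded.

```python
import torch

# Illustrative shapes for one decoder layer under SP prefill:
# T=9 unpadded tokens, tp=2, hidden H=4 -> padded length 10, 5 tokens/rank.
T, tp, H = 9, 2, 4
pad = (-T) % tp                         # 1 padding row
slice_size = (T + pad) // tp            # 5 tokens per rank

h = torch.randn(slice_size, H)          # shard entering the layer; norms see 1/tp tokens
full = torch.cat([h] * tp)[:T]          # allgather_unpadding_aligned -> [9, H] for attention
attn_out = torch.randn(T, H)            # o_proj partial sums (reduce_results=False)
padded = torch.nn.functional.pad(attn_out, (0, 0, 0, pad))  # [10, H]
h = padded.chunk(tp, dim=0)[0]          # padding_aligned_reduce_scatter -> [5, H] on rank 0
assert h.shape == (slice_size, H)       # post-attn norm and MoE run on the shard
```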

View File

@@ -47,6 +47,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
+from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
@@ -1347,7 +1348,8 @@ class AscendFusedMoE(FusedMoE):
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
                 gate=None,
-                replace_allreduce: bool = False):
+                replace_allreduce: bool = False,
+                _metadata_for_padding: Optional[MetadataForPadding] = None):

         assert self.quant_method is not None
@@ -1381,7 +1383,17 @@ class AscendFusedMoE(FusedMoE):
             # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
             shared_hidden_states = shared_experts(hidden_states)

+        mc2_mask = forward_context.mc2_mask
+
+        enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
         tp_size = get_tensor_model_parallel_world_size()
+        if enable_sp:
+            tp_rank = get_tensor_model_parallel_rank()
+            mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
+            chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
+            mc2_mask = chunk_mc2_mask[tp_rank]
+            replace_allreduce = True
+
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
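
The mc2_mask chunking above can be checked in isolation; a minimal sketch with made-up sizes (a padded mask of 12 tokens, 10 of them valid, split across 4 TP ranks):

```python
import torch

# Each TP rank keeps its own chunk of the padded validity mask.
tp_size, lengths_sum_padding, lengths_sum_unpadding = 4, 12, 10
mc2_mask_sp = torch.zeros(lengths_sum_padding, dtype=torch.bool)
mc2_mask_sp[:lengths_sum_unpadding] = True

chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
for tp_rank in range(tp_size):
    print(tp_rank, chunk_mc2_mask[tp_rank])
# Rank 3's chunk is [True, False, False]: one real token, two padding rows.
```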

View File

@@ -0,0 +1,120 @@
import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context

from vllm_ascend.platform import NPUPlatform


class MetadataForPadding:

    def __init__(self,
                 padding_flag=False,
                 lengths_sum_padding=0,
                 lengths_sum_unpadding=0,
                 pad_size=0,
                 not_dummy_and_is_prefill=False):
        self.padding_flag = padding_flag
        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill

        self.lengths_sum_padding = lengths_sum_padding
        self.lengths_sum_unpadding = lengths_sum_unpadding
        self.pad_size = pad_size

        self.tp_size = get_tp_group().world_size
        self.tp_rank_in_group = get_tp_group().rank_in_group

        assert self.lengths_sum_padding % self.tp_size == 0
        self.slice_size = self.lengths_sum_padding // self.tp_size

        self.mc2_mask = torch.zeros(
            self.lengths_sum_padding,
            dtype=torch.bool,
            device=NPUPlatform.device_type,
        )
        self.mc2_mask[:lengths_sum_unpadding] = True

    def padding_aligned_reduce_scatter(self,
                                       data: torch.Tensor) -> torch.Tensor:
        # Pad the token dimension up to a multiple of tp_size, then
        # reduce-scatter so each rank keeps one aligned slice.
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
            padded_data, 0)
        return padded_data_reduce_scatter

    def allgather_unpadding_aligned(self,
                                    padded_data: torch.Tensor) -> torch.Tensor:
        # Gather all slices back to the full padded sequence, then drop
        # the padding rows.
        padded_data_allgather = tensor_model_parallel_all_gather(
            padded_data, 0)
        if self.padding_flag:
            lengths_sum_unpadding = self.lengths_sum_unpadding
            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
        else:
            unpadding_data = padded_data_allgather
        return unpadding_data

    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
        # Pad, then keep only this rank's slice (no communication).
        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
        start = self.tp_rank_in_group * self.slice_size
        end = start + self.slice_size
        slice_data = padded_data[start:end]
        return slice_data

    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
        # Like padding_slice, but splits with tensor_split instead of
        # computing offsets explicitly.
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
        return padded_data[self.tp_rank_in_group]


def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
    if not enable_sequence_parallelism:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    is_prefill = 0
    attn_metadata = get_forward_context().attn_metadata
    tp_size = get_tensor_model_parallel_world_size()
    if attn_metadata is not None:
        if hasattr(attn_metadata,
                   'is_only_prefill') and attn_metadata.is_only_prefill:
            is_prefill = 1
        if hasattr(attn_metadata,
                   'num_prefills') and attn_metadata.num_prefills > 0:
            is_prefill = 1
        if is_prefill:
            lengths_sum_unpadding = input_ids.shape[0]
            lengths_sum_padding = (
                (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
            if lengths_sum_unpadding == lengths_sum_padding:
                padding_flag = False
            else:
                padding_flag = True
            pad_size = lengths_sum_padding - lengths_sum_unpadding
            _metadata_for_padding = MetadataForPadding(
                lengths_sum_unpadding=lengths_sum_unpadding,
                lengths_sum_padding=lengths_sum_padding,
                padding_flag=padding_flag,
                pad_size=pad_size,
                not_dummy_and_is_prefill=True)
            return _metadata_for_padding

    return MetadataForPadding(padding_flag=False,
                              not_dummy_and_is_prefill=False)
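
A worked example of the padding arithmetic in `init_metadata_for_sp` (the token count and TP size are illustrative):

```python
# 10 prefill tokens across the batch, tp_size = 4.
tp_size = 4
lengths_sum_unpadding = 10
lengths_sum_padding = (
    (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
pad_size = lengths_sum_padding - lengths_sum_unpadding
slice_size = lengths_sum_padding // tp_size

print(lengths_sum_padding, pad_size, slice_size)  # 12 2 3
# padding_slice on rank r keeps rows [r * 3, r * 3 + 3); rank 3's slice
# holds 1 real token and 2 padded rows, which mc2_mask marks invalid.
```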

View File

@@ -195,6 +195,12 @@ class NPUPlatform(Platform):
                 ascend_config.ascend_scheduler_config)
             vllm_config.scheduler_config = ascend_scheduler_config

+        if compilation_config.pass_config.enable_sequence_parallelism:
+            if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
+                raise NotImplementedError(
+                    "SP for Qwen3 MoE requires expert parallelism; it is only supported on the MC2, AllToAll, and AllToAllV communication paths."
+                )
+
         # register Ascend CustomOp
         register_ascend_customop()

View File

@@ -1160,6 +1160,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]

+        is_only_prefill = bool(np.all(num_valid_tokens != 1))
+        extra_builder_kwargs['is_only_prefill'] = is_only_prefill
+
         enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
                                               attn_state,
                                               total_num_scheduled_tokens)
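
The `is_only_prefill` heuristic above treats a request scheduling exactly one token as decode, so the batch counts as "only prefill" when no request has `num_valid_tokens == 1`. A self-contained sketch (array values are illustrative):

```python
import numpy as np

num_valid_tokens = np.array([7, 5, 12])        # pure prefill batch
print(bool(np.all(num_valid_tokens != 1)))     # True -> SP path may engage

num_valid_tokens = np.array([7, 1, 12])        # mixed batch with a decode
print(bool(np.all(num_valid_tokens != 1)))     # False
```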