[Refactor] [SP] Integrate the sequence parallelism features of the MoE and dense models into a single solution. (#3085)
What this PR does / why we need it?
There are currently two sequence parallelism (SP) implementations for the MoE and dense models: one is called sequence_parallelism, the other flashcomm_v1.
This PR does the following:
1. Merges the two code paths, which implement the same thing, into one.
2. Removes the sequence_parallelism implementation, because that solution cannot support aclgraph.
A sketch of the idea shared by both paths is given below.
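For context, both paths implement the same scheme: during prefill the tokens are padded so their count is divisible by the tensor parallel (TP) world size, the padded activations are reduce-scattered so each rank runs the layernorms and the MLP/MoE on only its own slice, and an all-gather restores the full sequence before the next attention. A minimal sketch of the padding step, with illustrative names that are not the vllm-ascend API (the real logic lives in vllm_ascend.ops.sequence_parallel):

import torch

def pad_for_sequence_parallel(hidden_states: torch.Tensor,
                              tp_world_size: int) -> tuple[torch.Tensor, int]:
    # Pad the token dimension so it divides evenly across TP ranks.
    # Illustrative helper only, not the MetadataForPadding implementation.
    num_tokens = hidden_states.shape[0]
    pad_len = (-num_tokens) % tp_world_size  # 0 if already aligned
    if pad_len:
        pad = hidden_states.new_zeros((pad_len, *hidden_states.shape[1:]))
        hidden_states = torch.cat([hidden_states, pad], dim=0)
    return hidden_states, pad_len

# Example: 10 prefill tokens, TP=4 -> padded to 12, each rank owns 3 tokens.
x, pad_len = pad_for_sequence_parallel(torch.randn(10, 8), tp_world_size=4)
assert x.shape[0] == 12 and pad_len == 2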
Does this PR introduce any user-facing change?
No
How was this patch tested?
e2e & unit tests
- vLLM version: v0.10.2
- vLLM main: f225ea7dd9
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -17,14 +17,14 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
 from vllm.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
 
 
 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -100,7 +97,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
         self,
         hidden_states,
         attn_metadata=None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
     ):
         if attn_metadata is None:
             attn_metadata = get_forward_context().attn_metadata
@@ -119,7 +115,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
             top_k=self.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=None,
-            _metadata_for_padding=_metadata_for_padding,
         )
 
         return hidden_states
@@ -188,60 +183,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
-        self.enable_sequence_parallelism = (
-            vllm_config.compilation_config.pass_config.
-            enable_sequence_parallelism if vllm_config is not None else False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> torch.Tensor:
-
-        # To prevent precision issues during the decoder phase when only prefilling enables SP
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
-
-        # Self Attention
-        if residual is None:
-            residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-                residual = _metadata_for_padding.padding_slice(residual)
-
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
-
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-
-        if not self.use_aclgraph:
-            hidden_states = self.mlp(
-                hidden_states, _metadata_for_padding=_metadata_for_padding)
-        else:
-            hidden_states = self.mlp(hidden_states)
-
-        return hidden_states, residual
-
 
 @support_torch_compile
 class CustomQwen3MoeModel(Qwen3MoeModel):
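The removed forward above is also the core of what the merged flashcomm_v1 path keeps: attention needs the full sequence, so the sharded hidden states are all-gathered right before self-attention and reduce-scattered right after it, while the layernorms and the MLP/MoE only ever see 1/TP of the (padded) tokens. Skipping o_proj's usual all-reduce (reduce_results=False) is safe because the reduce-scatter plus the later all-gather compute the same sum. A rough sketch of that wrapper using plain torch.distributed collectives; the helper name and the attn callable are illustrative, not the vllm-ascend implementation:

import torch
import torch.distributed as dist

def attention_with_sequence_parallel(attn, shard: torch.Tensor,
                                     tp_group) -> torch.Tensor:
    # shard: this rank's slice of the padded token dimension.
    tp = dist.get_world_size(tp_group)
    # 1) All-gather the shards so attention sees every token.
    full = shard.new_empty((shard.shape[0] * tp, *shard.shape[1:]))
    dist.all_gather_into_tensor(full, shard, group=tp_group)
    # 2) Run attention on the full sequence; its output projection must not
    #    all-reduce here (reduce_results=False in the removed code above).
    full = attn(full)
    # 3) Reduce-scatter sums the partial o_proj outputs across TP ranks and
    #    hands each rank back only its own token slice for norm + MLP/MoE.
    out = torch.empty_like(shard)
    dist.reduce_scatter_tensor(out, full, group=tp_group)
    return out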
@@ -277,45 +218,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        _metadata_for_padding: Optional[MetadataForPadding] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-                _metadata_for_padding=_metadata_for_padding)
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-
-        hidden_states, _ = self.norm(hidden_states, residual)
-
-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
-                hidden_states)
-
-        return hidden_states
-
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
 
@@ -340,7 +242,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []
 
@@ -361,16 +262,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, _metadata_for_padding)
-        return hidden_states
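Why the reduce_results toggle in the removed decoder layer preserves numerics: a reduce-scatter of the partial o_proj outputs followed by an all-gather of the resulting shards is exactly an all-reduce, just split into two steps so the MLP/MoE can run on shards in between. A quick single-process check that simulates four TP ranks with chunks of a tensor:

import torch

tp = 4
# Partial o_proj outputs each TP rank would hold before any collective.
partials = [torch.randn(12, 8) for _ in range(tp)]

# all-reduce: every rank ends up with the elementwise sum of all partials.
all_reduced = sum(partials)

# reduce-scatter: rank r gets the sum of everyone's r-th token chunk ...
scattered = [sum(p.chunk(tp, dim=0)[r] for p in partials) for r in range(tp)]
# ... all-gather: concatenating the chunks restores the full summed tensor.
gathered = torch.cat(scattered, dim=0)

assert torch.allclose(all_reduced, gathered)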