diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index e5ce07f..8888b70 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context, set_forward_context) import vllm_ascend.envs as envs_ascend +from vllm_ascend.utils import enable_sp class FusedMoEState(Enum): @@ -101,21 +102,19 @@ def set_ascend_forward_context( # due to multiple warmups before actual capturing forward_context.capturing = False - # set for flashcomm_v1, 1000 is the batchsize concurrency threshold for enabling the flashcomm_v1 feature. + # Set for sequence parallelism: 1000 is the batch-size concurrency threshold (compared against num_tokens) above which the flashcomm_v1 or sequence_parallelism feature is enabled. # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. 
- flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \ - envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \ + sp_enabled = enable_sp() and \ tp_world_size > 1 and \ num_tokens is not None and num_tokens > 1000 - if flashcomm_v1_enabled: + if sp_enabled: pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size forward_context.pad_size = pad_size - - forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled + forward_context.sp_enabled = sp_enabled # set this for rope forward_oot using forward_context.is_first_layer = True diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 511d3ad..963a947 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -163,7 +163,6 @@ class AscendMetadata: # *************************** Other Properties *************************** # enable_dbo_across_dp: bool = False - is_only_prefill: bool = False class AscendAttentionMetadataBuilder: @@ -236,8 +235,7 @@ class AscendAttentionMetadataBuilder: slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, - enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp, - is_only_prefill=common_attn_metadata.is_only_prefill) + enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp) return attn_metadata def build_for_graph_capture( diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 7ad54a2..711e291 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -17,14 +17,14 @@ # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. 
-from typing import Optional, Union +from typing import Optional import torch from torch import nn from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, CompilationLevel, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, from vllm.model_executor.models.utils import ( PPMissingLayer, extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, - init_metadata_for_sp) class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -100,7 +97,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): self, hidden_states, attn_metadata=None, - _metadata_for_padding: Optional[MetadataForPadding] = None, ): if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata @@ -119,7 +115,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): top_k=self.top_k, enable_force_load_balance=enable_force_load_balance, shared_experts=None, - _metadata_for_padding=_metadata_for_padding, ) return hidden_states @@ -188,60 +183,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.enable_sequence_parallelism = ( - vllm_config.compilation_config.pass_config. 
- enable_sequence_parallelism if vllm_config is not None else False) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> torch.Tensor: - - # To prevent precision issues during the decoder phase when only prefilling enables SP - if not self.enable_sequence_parallelism: - self.self_attn.o_proj.reduce_results = True - else: - self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True - - # Self Attention - if residual is None: - residual = hidden_states - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - residual = _metadata_for_padding.padding_slice(residual) - - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter( - hidden_states) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if not self.use_aclgraph: - hidden_states = self.mlp( - hidden_states, _metadata_for_padding=_metadata_for_padding) - else: - hidden_states = self.mlp(hidden_states) - - return hidden_states, residual - @support_torch_compile class CustomQwen3MoeModel(Qwen3MoeModel): @@ -277,45 +218,6 @@ class CustomQwen3MoeModel(Qwen3MoeModel): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - 
intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - _metadata_for_padding=_metadata_for_padding) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - return hidden_states - class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): @@ -340,7 +242,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism # Set MoE hyperparameters self.expert_weights: list[torch.Tensor] = [] @@ -361,16 +262,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): self.num_moe_layers = len(self.moe_layers) self.num_expert_groups = 1 self.num_shared_experts = 0 - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - _metadata_for_padding = init_metadata_for_sp( - input_ids, 
self.enable_sequence_parallelism) - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, _metadata_for_padding) - return hidden_states diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 57beae2..554b40e 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -216,7 +216,9 @@ class AscendFusedMoE(FusedMoE): forward_context = get_forward_context() hidden_states, router_logits = forward_context.moe_comm_method.prepare( - hidden_states=hidden_states, router_logits=router_logits) + hidden_states=hidden_states, + router_logits=router_logits, + replace_allreduce=forward_context.sp_enabled) # Matrix multiply. final_hidden_states = self.quant_method.apply( diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 4fd4b1b..533c20b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -21,8 +21,7 @@ from typing import Any, Callable, Optional import torch import torch_npu from vllm.config import get_current_vllm_config -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_tp_group) from vllm.forward_context import get_forward_context @@ -42,7 +41,6 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method -from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, get_all_reduce_merge_state, get_rm_router_logits_state, is_310p, @@ -360,8 +358,7 @@ class AscendFusedMoE(FusedMoE): top_k: Optional[int] = None, shared_experts: Optional[Any] = None, 
gate=None, - replace_allreduce: bool = False, - _metadata_for_padding: Optional[MetadataForPadding] = None): + replace_allreduce: bool = False): assert self.quant_method is not None @@ -379,13 +376,7 @@ class AscendFusedMoE(FusedMoE): # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce shared_hidden_states = shared_experts(hidden_states) - enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill - tp_size = get_tensor_model_parallel_world_size() - if enable_sp: - tp_rank = get_tensor_model_parallel_rank() - mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask - chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] + if forward_context.sp_enabled: replace_allreduce = True hidden_states, router_logits = forward_context.moe_comm_method.prepare( diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index cdb3b98..57044f5 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -48,8 +48,9 @@ from vllm.distributed.parallel_state import get_tp_group from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable, - mlp_tp_enable, oproj_tp_enable) +from vllm_ascend.utils import (dense_optim_enable, enable_sp, + matmul_allreduce_enable, mlp_tp_enable, + oproj_tp_enable) class CustomTensorParallelOp: @@ -82,10 +83,17 @@ class CustomTensorParallelOp: self.skip_bias_add = self.layer.skip_bias_add self.return_bias = self.layer.return_bias self.quant_method = self.layer.quant_method + self.prefix = self.layer.prefix + + def apply_impl(self, input_): + raise NotImplementedError # Replace layer.forward to customize the layer computation process. 
def apply(self, input_): - raise NotImplementedError + output, output_bias = self.apply_impl(input_) + if not self.return_bias: + return output + return output, output_bias class CustomColumnParallelOp(CustomTensorParallelOp): @@ -113,6 +121,14 @@ class CustomRowParallelOp(CustomTensorParallelOp): self.reduce_results = self.layer.reduce_results self.input_size_per_partition = self.layer.input_size_per_partition + def apply(self, input_): + output, output_bias = self.apply_impl(input_) + if dense_optim_enable(): + torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) + if not self.return_bias: + return output + return output, output_bias + class MLPColumnParallelOp(CustomColumnParallelOp): @@ -123,7 +139,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp): def comm_group(self): return get_mlp_tp_group() - def apply( + def apply_impl( self, input_: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: @@ -134,14 +150,12 @@ class MLPColumnParallelOp(CustomColumnParallelOp): output = self.quant_method.apply(self.layer, input_parallel, bias) output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias -class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp): +class SequenceMergedColumnParallelOp(CustomColumnParallelOp): - def apply( + def apply_impl( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: """Linear layer with column parallelism. 
@@ -164,18 +178,16 @@ class DenseOptimMergedColumnParallelOp(CustomColumnParallelOp): else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias -class DenseOptimQKVParallelOp(CustomColumnParallelOp): +class SequenceQKVParallelOp(CustomColumnParallelOp): def __init__(self, layer, prefix): super().__init__(layer) self.prefix = prefix - def apply( + def apply_impl( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: """Linear layer with column parallelism. @@ -201,8 +213,6 @@ class DenseOptimQKVParallelOp(CustomColumnParallelOp): else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias @@ -215,7 +225,7 @@ class MLPRowParallelOp(CustomRowParallelOp): def comm_group(self): return get_mlp_tp_group() - def apply( + def apply_impl( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.input_is_parallel: @@ -234,8 +244,6 @@ class MLPRowParallelOp(CustomRowParallelOp): output = self.comm_group.reduce_scatter(output_parallel, 0) output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias @@ -248,7 +256,7 @@ class OProjRowParallelOp(CustomRowParallelOp): def comm_group(self): return get_otp_group() - def apply( + def apply_impl( self, input_: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: @@ -294,8 +302,6 @@ class OProjRowParallelOp(CustomRowParallelOp): # Handle bias return based on configuration output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias def update_attrs(self): @@ -311,7 +317,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp): super().__init__(layer) self.hcomm_info = 
self.get_hcomm_info(self.comm_group.device_group) - def apply( + def apply_impl( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.input_is_parallel: @@ -335,8 +341,6 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp): bias=bias_) output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output return output, output_bias @classmethod @@ -359,13 +363,13 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp): self.weight_t = self.layer.weight.t() -class DenseOptimRowParallelOp(CustomRowParallelOp): +class SequenceRowParallelOp(CustomRowParallelOp): def __init__(self, layer, prefix): super().__init__(layer) self.prefix = prefix - def apply( + def apply_impl( self, input_: torch.Tensor ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: """Linear layer with column parallelism. @@ -391,12 +395,8 @@ class DenseOptimRowParallelOp(CustomRowParallelOp): input_parallel, bias=bias_) output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) - torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix) output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output return output, output_bias def update_attrs(self): @@ -407,23 +407,22 @@ class DenseOptimRowParallelOp(CustomRowParallelOp): def get_column_parallel_op( disable_tp, prefix, layer -) -> Tuple[ - Optional[Union[MLPColumnParallelOp, DenseOptimMergedColumnParallelOp, - DenseOptimQKVParallelOp]], int, int]: +) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, + SequenceQKVParallelOp]], int, int]: if disable_tp: return None, 0, 1 custom_op: Optional[Union[ MLPColumnParallelOp, - DenseOptimMergedColumnParallelOp, - DenseOptimQKVParallelOp, + SequenceMergedColumnParallelOp, + SequenceQKVParallelOp, ]] = None if "gate_up_proj" in prefix and mlp_tp_enable(): custom_op = MLPColumnParallelOp(layer) - elif "gate_up_proj" in prefix and 
dense_optim_enable(): - custom_op = DenseOptimMergedColumnParallelOp(layer) - elif dense_optim_enable(): - custom_op = DenseOptimQKVParallelOp(layer, prefix) + elif "gate_up_proj" in prefix and enable_sp(): + custom_op = SequenceMergedColumnParallelOp(layer) + elif enable_sp(): + custom_op = SequenceQKVParallelOp(layer, prefix) if custom_op is not None: return custom_op, custom_op.tp_rank, custom_op.tp_size @@ -435,21 +434,21 @@ def get_row_parallel_op( disable_tp, prefix, layer ) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, - DenseOptimRowParallelOp]], int, int]: + SequenceRowParallelOp]], int, int]: if disable_tp: return None, 0, 1 custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, - DenseOptimRowParallelOp]] = None + SequenceRowParallelOp]] = None if "down_proj" in prefix and mlp_tp_enable(): custom_op = MLPRowParallelOp(layer) elif "o_proj" in prefix and oproj_tp_enable(): custom_op = OProjRowParallelOp(layer) elif matmul_allreduce_enable(): custom_op = MatmulAllreduceRowParallelOp(layer) - elif dense_optim_enable(): - custom_op = DenseOptimRowParallelOp(layer, prefix) + elif enable_sp(): + custom_op = SequenceRowParallelOp(layer, prefix) if custom_op is not None: return custom_op, custom_op.tp_rank, custom_op.tp_size diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index bc0d4fb..1d6df2c 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -133,11 +133,15 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp + forward_context = get_forward_context() + mc2_mask = forward_context.mc2_mask + if self.tp_size > 1: + # Slice mc2_mask across TP ranks up front, so it is sliced whether or not replace_allreduce is set + split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0) + 
mc2_mask = split_mc2_mask[self.tp_rank] if not self.replace_allreduce: self.num_tokens, _ = hidden_states.shape - forward_context = get_forward_context() - mc2_mask = forward_context.mc2_mask target_pad_length = forward_context.padded_num_tokens pad_size = target_pad_length - self.num_tokens @@ -149,23 +153,16 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize): (0, 0, 0, pad_size)) # Slice across TP ranks - if self.tp_size > 1: - if not self.enable_shared_expert_dp: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) - hidden_states = split_hidden_states[self.tp_rank] - router_logits = split_router_logits[self.tp_rank] - self.split_hidden_states = split_hidden_states # Save for finalize - - # Also slice mc2_mask - split_mc2_mask = torch.tensor_split(mc2_mask, - self.tp_size, - dim=0) - mc2_mask = split_mc2_mask[self.tp_rank] + if self.tp_size > 1 and not self.enable_shared_expert_dp: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + self.split_hidden_states = split_hidden_states # Save for finalize return hidden_states, router_logits, mc2_mask diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 1267066..a702b35 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -20,10 +20,9 @@ def _maybe_chunk_residual_impl(x: torch.Tensor, return residual if x.size(0) != residual.size(0): - flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled - assert flashcomm_v1_enabled is True, ( - "Currently, this situation only occurs " - "when flashcomm_v1 is enabled") + sp_enabled = forward_context.sp_enabled + assert sp_enabled is True, 
("Currently, this situation only occurs " + "when sp is enabled") pad_size = forward_context.pad_size if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) @@ -41,8 +40,8 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, except AssertionError: return x - flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled - if flashcomm_v1_enabled and label: + sp_enabled = forward_context.sp_enabled + if sp_enabled and label: x = tensor_model_parallel_all_gather(x, 0) pad_size = forward_context.pad_size if pad_size > 0: @@ -56,8 +55,8 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor: except AssertionError: return tensor_model_parallel_all_reduce(x) - flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled - if flashcomm_v1_enabled: + sp_enabled = forward_context.sp_enabled + if sp_enabled: pad_size = forward_context.pad_size if pad_size > 0: x = F.pad(x, (0, 0, 0, pad_size)) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4bd29b1..f00abca 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -282,12 +282,6 @@ class NPUPlatform(Platform): ascend_config.ascend_scheduler_config) vllm_config.scheduler_config = ascend_scheduler_config - if compilation_config.pass_config.enable_sequence_parallelism: - if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe": - raise NotImplementedError( - "For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV." 
- ) - @classmethod def get_attn_backend_cls(cls, selected_backend, diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py index eaed918..8093ad4 100644 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -54,8 +54,8 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, - init_metadata_for_sp) +from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, + init_metadata_for_sp) class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): diff --git a/vllm_ascend/ops/sequence_parallel.py b/vllm_ascend/torchair/ops/sequence_parallel.py similarity index 100% rename from vllm_ascend/ops/sequence_parallel.py rename to vllm_ascend/torchair/ops/sequence_parallel.py diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 0c85c85..967aa03 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -44,8 +44,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map, determine_default_log2phy_map) from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.ops.sequence_parallel import MetadataForPadding from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod +from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 2b2b540..570756f 100644 --- a/vllm_ascend/utils.py +++ 
b/vllm_ascend/utils.py @@ -590,6 +590,14 @@ def dense_optim_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE +def enable_sp() -> bool: + from vllm.config import get_cached_compilation_config + + return ( + get_cached_compilation_config().pass_config.enable_sequence_parallelism + or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM) + + def is_moe_model(vllm_config: VllmConfig): config = vllm_config.model_config.hf_config return any('experts' in key.lower() for key in config.to_dict()) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9c33059..98aeac6 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1582,7 +1582,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): update_attn_params(self.update_stream, forward_context, positions.shape[0]) - if get_forward_context().flashcomm_v1_enabled: + if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) pad_size = get_forward_context().pad_size if pad_size > 0: