From 2ad0ca52a6eb391816ad6c444cd57968c9e13e77 Mon Sep 17 00:00:00 2001
From: wangbj127 <256472688+wangbj127@users.noreply.github.com>
Date: Wed, 25 Mar 2026 23:09:33 +0800
Subject: [PATCH] Qwen3.5 MoE supports flashcomm v1 (#7644)

cherry pick from https://github.com/vllm-project/vllm-ascend/pull/7486

### What this PR does / why we need it?
Multimodal models like Qwen3.5 MoE do their input embedding in model_runner, so when flash comm v1 is enabled the first AllGather operation should be skipped: the hidden states entering the first decoder layer are already full-length rather than sequence-chunked.
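The check below is a minimal sketch of that skip condition (illustration only, not part of the diff); it mirrors the `SequenceColumnParallelOp` change further down. `extract_layer_index` and `is_vl_model` are the helpers the diff imports, while `need_all_gather`, `layer_prefix`, and `proj_prefix` are illustrative names:

```python
from vllm.model_executor.models.utils import extract_layer_index

from vllm_ascend.utils import is_vl_model


def need_all_gather(layer_prefix: str, proj_prefix: str) -> bool:
    """Decide whether a column-parallel projection should AllGather its input.

    With flash comm v1 the hidden states are sequence-chunked across TP ranks
    and gathered back in front of each projection. VL models embed full-length
    hidden states in model_runner, so the attention input of layer 0 must not
    be gathered again.
    """
    is_first_vl_attn = (
        extract_layer_index(layer_prefix) == 0  # first decoder layer
        and is_vl_model()                       # embedding already done in model_runner
        and "attn" in proj_prefix               # only the attention projection
    )
    return not is_first_vl_attn
```

In the actual change this boolean is passed to `torch.ops.vllm.maybe_all_gather_and_maybe_unpad` as its `label` argument.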
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c

---------

Signed-off-by: Wangbingjie
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
---
 tests/e2e/multicard/4-cards/test_qwen3_5.py |  32 ++++-
 vllm_ascend/ops/layernorm.py                |   1 +
 vllm_ascend/ops/linear_op.py                |   6 +-
 vllm_ascend/ops/register_custom_ops.py      |   4 +-
 vllm_ascend/patch/worker/patch_qwen3_5.py   | 126 +++++++++++++++++++-
 vllm_ascend/spec_decode/eagle_proposer.py   |  15 ++-
 vllm_ascend/utils.py                        |   6 +-
 7 files changed, 182 insertions(+), 8 deletions(-)

diff --git a/tests/e2e/multicard/4-cards/test_qwen3_5.py b/tests/e2e/multicard/4-cards/test_qwen3_5.py
index 475086a0..f6008d98 100644
--- a/tests/e2e/multicard/4-cards/test_qwen3_5.py
+++ b/tests/e2e/multicard/4-cards/test_qwen3_5.py
@@ -16,7 +16,9 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
 #
+import os
 from tests.e2e.conftest import VllmRunner
+from unittest.mock import patch
 
 
 def test_qwen3_5_27b_distributed_mp_tp4():
@@ -72,4 +74,32 @@ def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3():
                         "num_speculative_tokens": 3,
                     }) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
-    del vllm_model
\ No newline at end of file
+    del vllm_model
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3_flashcomm():
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    max_tokens = 20
+    with VllmRunner("Qwen/Qwen3.5-35B-A3B",
+                    tensor_parallel_size=4,
+                    enable_expert_parallel=True,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.90,
+                    distributed_executor_backend="mp",
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [4, 8, 12, 16],
+                    },
+                    speculative_config={
+                        "method": "qwen3_5_mtp",
+                        "num_speculative_tokens": 3,
+                    }) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 71ba7e73..2c3a8b0f 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -97,6 +97,7 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
         import torch_npu
 
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             if enable_custom_op():
                 x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
                     x, residual, 1.0 + self.weight, None, self.variance_epsilon
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 6ba9c462..85d29d3f 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -57,6 +57,7 @@ from vllm.distributed import (
     tensor_model_parallel_reduce_scatter,
 )
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.model_executor.models.utils import extract_layer_index
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
@@ -74,6 +75,7 @@ from vllm_ascend.utils import (
     flashcomm2_enable,
     get_flashcomm2_reorgnized_batch_ids,
     get_weight_prefetch_method,
+    is_vl_model,
     matmul_allreduce_enable,
     mlp_tp_enable,
     oproj_tp_enable,
@@ -430,8 +432,8 @@ class SequenceColumnParallelOp(CustomColumnParallelOp):
 
         # Matrix multiply.
         assert self.quant_method is not None
-
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        need_all_gather = not (extract_layer_index(self.layer.prefix) == 0 and is_vl_model() and "attn" in self.prefix)
+        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, label=need_all_gather)
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
 
         if self.gather_output:
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index c500fb83..dea1e0bf 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -17,7 +17,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.rotary_embedding import rope_forward_oot
 from vllm_ascend.ops.triton.muls_add import muls_add_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
+from vllm_ascend.utils import enable_sp_by_pass, is_vl_model, npu_stream_switch, prefetch_stream
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
         enable_sp_by_pass() and is_ep_comm
     )
 
-    if not flash_comm_v1_enabled:
+    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
         return tensor_model_parallel_all_reduce(x)
 
     dp_metadata = forward_context.dp_metadata
diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py
index d19d4ade..3c78d2f9 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_5.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_5.py
@@ -18,15 +18,18 @@
 
 
 import torch
+from einops import rearrange
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
-from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
+from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
 from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
 from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
+from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
 from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
@@ -41,6 +44,64 @@ def to_int64_tuple(t):
 
 
 class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+    ):
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        num_tokens = mixed_qkvz.size(0)
+        qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
+        z_size = self.value_dim // self.tp_size
+        mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
+        z = z.reshape(z.size(0), -1, self.head_v_dim)
+        ba, _ = self.in_proj_ba(hidden_states)
+        b, a = ba.chunk(2, dim=-1)
+
+        b = b.contiguous()
+        a = a.contiguous()
+
+        # ============================================================
+        # Part 2: Core Attention (Custom Op)
+        # ============================================================
+        # Note: we should not use torch.empty here like other attention backends,
+        # see discussions in https://github.com/vllm-project/vllm/pull/28182
+        core_attn_out = torch.zeros(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.gdn_attention_core(
+            mixed_qkv,
+            b,
+            a,
+            core_attn_out,
+            self.prefix,
+        )
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        o_out, _ = self.out_proj(core_attn_out)
+        actual_num_tokens = o_out.shape[0]
+        output[:actual_num_tokens] = o_out
+
     def _forward_core(
         self,
         mixed_qkv: torch.Tensor,
@@ -320,5 +381,68 @@ class AscendQwen3NextAttention(Qwen3NextAttention):
         output[:], _ = self.o_proj(attn_output)
 
 
+class AscendQwen3_5DecoderLayer(Qwen3_5DecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+        positions: torch.Tensor = None,
+        **kwargs: object,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.layer_idx == 0 and _EXTRA_CTX.flash_comm_v1_enabled:
+            tp_size = get_tensor_model_parallel_world_size()
+            n_out = (hidden_states.shape[0] + tp_size - 1) // tp_size
+            hidden_dim = hidden_states.shape[-1]
+            self_attention_output = torch.empty(
+                (n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            )
+        else:
+            self_attention_output = torch.empty_like(hidden_states)
+
+        if self.layer_type == "linear_attention":
+            self.linear_attn(
+                hidden_states=hidden_states,
+                output=self_attention_output,
+            )
+        elif self.layer_type == "full_attention":
+            self.self_attn(
+                hidden_states=hidden_states,
+                output=self_attention_output,
+                positions=positions,
+            )
+        else:
+            raise ValueError("Invalid layer_type")
+        hidden_states = self_attention_output
+
+        if self.layer_scale:
+            if len(hidden_states.shape) == 2:
+                hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
+            else:
+                hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype) + 1)
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if self.layer_scale:
+            if len(hidden_states.shape) == 2:
+                hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
+            else:
+                assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
+                    f"shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}"
+                )
+                hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype) + 1)
+
+        return hidden_states, residual
+
+
+Qwen3_5GatedDeltaNet.forward = AscendQwen3_5GatedDeltaNet.forward
 Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
 Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
+Qwen3_5DecoderLayer.forward = AscendQwen3_5DecoderLayer.forward
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 72ee9d9c..4da0baff 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -155,6 +155,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         ]
 
         self._runnable = self._run_merged_draft
+        self.is_multimodal_model = self.vllm_config.model_config.is_multimodal_model
         if self.uses_mrope:
             self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
         elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
@@ -414,6 +415,16 @@ class SpecDecodeBaseProposer(EagleProposer):
         if is_profile:
             batch_size = min(batch_size, self.runner.max_num_reqs)
 
+        if self.supports_mm_inputs:
+            mm_embeds, is_mm_embed = (None, None)
+            inputs_embeds = self.model.embed_input_ids(
+                self.input_ids[:num_tokens], multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
+            )
+            self.inputs_embeds[:num_tokens] = inputs_embeds
+            inputs_embeds = self.inputs_embeds[:num_tokens]
+        else:
+            inputs_embeds = None
+
         with set_ascend_forward_context(
             multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
             self.vllm_config,
@@ -437,7 +448,7 @@ class SpecDecodeBaseProposer(EagleProposer):
                 token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
                 # The target_position's address is same as the model_positions's
                 target_positions=model_positions,
-                inputs_embeds=None,
+                inputs_embeds=inputs_embeds,
                 multi_steps_attn_metadata=multi_steps_attn_metadata,
                 num_tokens=num_tokens,
             )
@@ -1581,6 +1592,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
+            return hidden_states, positions
         if self.method == "mtp":
             if _EXTRA_CTX.flash_comm_v1_enabled:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 4db163c7..2c35788c 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -849,7 +849,7 @@ def _is_contain_expert(config: Any):
     return False
 
 
-def is_vl_model(vllm_config: VllmConfig):
+def is_vl_model(vllm_config: VllmConfig = None):
     """Checks if the model is a VL model by config.
 
     Uses the same criterion as vllm itself (model_config.py): a model is
@@ -859,6 +859,10 @@ def is_vl_model(vllm_config: VllmConfig):
     self (rare but possible).
     """
     global _IS_VL_MODEL
+    if vllm_config is None:
+        from vllm.config import get_current_vllm_config_or_none
+
+        vllm_config = get_current_vllm_config_or_none()
     if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
         model_config = vllm_config.model_config
         # Primary: vllm's own VL detection — hf_config is the top-level