[Bugfix] Qwen3Next: support FlashComm1 (#6830)

### What this PR does / why we need it?
Support FlashComm1 for Qwen3-Next. Fix padding problems in Sequence Parallel (SP)
and resolve precision problems in shared_out when both FlashComm1 and SP are
enabled.
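
For readers unfamiliar with the failure mode, the SP padding issue follows a common pattern: with sequence parallelism the token dimension is padded so it splits evenly across tensor-parallel ranks, and slicing intermediate tensors back to `num_actual_tokens` too early breaks the padded layout that the later collectives expect. A minimal sketch of that pattern, using an illustrative `pad_for_sp` helper (not the vLLM-Ascend API):

```python
import math

import torch


def pad_for_sp(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    """Illustrative helper (not the vLLM-Ascend API): pad the token dim to a
    multiple of tp_size so the sequence splits evenly across TP ranks."""
    num_tokens = x.size(0)
    padded_len = math.ceil(num_tokens / tp_size) * tp_size
    if padded_len == num_tokens:
        return x
    pad = x.new_zeros(padded_len - num_tokens, *x.shape[1:])
    return torch.cat([x, pad], dim=0)


x = torch.randn(10, 64)            # 10 actual tokens
padded = pad_for_sp(x, tp_size=4)  # padded to 12 so a 4-way split is even
# Slicing intermediate tensors back to the actual length would break that
# even layout, which is why the diff below guards such slices with
# `if not enable_sp()`.
assert padded.size(0) == 12 and padded.size(0) % 4 == 0
```
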

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: zhaojiangjiang <zhaojiangjiang1@h-partners.com>
Co-authored-by: zhaojiangjiang <zhaojiangjiang1@h-partners.com>
Author: ZhaoJiangJiang
Date: 2026-03-06 17:14:08 +08:00
Committed by: GitHub
Parent: a2696006d1
Commit: a51d6366b9
4 changed files with 63 additions and 8 deletions


@@ -30,6 +30,7 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm_ascend.ops.triton.fla.fused_qkvzba_split_reshape import fused_qkvzba_split_reshape_cat
 from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
 from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
+from vllm_ascend.utils import enable_sp
 class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
@@ -44,13 +45,13 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         2. Core attention (custom op)
         3. Output projection
         """
-        num_tokens = hidden_states.size(0)
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
         projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        num_tokens = projected_states_qkvz.size(0)
         mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
             projected_states_qkvz,
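
One possible reading of the hunk above (an interpretation, not stated in the PR): with FlashComm1/SP active, the input projection can return a tensor whose token dimension is the padded, gathered length rather than the local input length, so `num_tokens` is taken from the tensor the subsequent split/reshape actually consumes. A hypothetical sketch (`project_and_pad` and `pad_to` are illustrative names only):

```python
import torch


def project_and_pad(hidden_states: torch.Tensor, weight: torch.Tensor,
                    pad_to: int) -> torch.Tensor:
    """Hypothetical stand-in for a projection whose output token dim is padded
    (e.g. so it divides evenly across ranks) and may differ from the input's."""
    out = hidden_states @ weight
    remainder = out.size(0) % pad_to
    if remainder:
        pad = out.new_zeros(pad_to - remainder, out.size(1))
        out = torch.cat([out, pad], dim=0)
    return out


hidden_states = torch.randn(5, 8)
projected = project_and_pad(hidden_states, torch.randn(8, 16), pad_to=4)
# Sizing downstream buffers from the input token count would now be wrong:
assert hidden_states.size(0) == 5 and projected.size(0) == 8
```
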
@@ -126,9 +127,10 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
-        mixed_qkv = mixed_qkv[:num_actual_tokens]
-        b = b[:num_actual_tokens]
-        a = a[:num_actual_tokens]
+        if not enable_sp():
+            mixed_qkv = mixed_qkv[:num_actual_tokens]
+            b = b[:num_actual_tokens]
+            a = a[:num_actual_tokens]
         # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
@@ -292,11 +294,20 @@ class AscendQwen3Next_GatedDeltaNet(nn.Module, MambaBase):
             )
             merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
             merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
         elif spec_sequence_masks is not None:
-            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
         else:
-            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            if not enable_sp():
+                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            else:
+                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
 Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward
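
The diff's final line swaps the upstream `Qwen3NextGatedDeltaNet.forward` for the Ascend implementation in place, so existing model wiring picks up the new behaviour without subclassing. A minimal sketch of that patching idiom, with toy classes standing in for the real ones:

```python
class Upstream:
    """Toy stand-in for the upstream vLLM layer."""

    def forward(self, x: int) -> int:
        return x  # original behaviour


class PlatformImpl:
    """Toy stand-in for the platform-specific replacement."""

    def forward(self, x: int) -> int:
        return x + 1  # patched behaviour


# Same idiom as `Qwen3NextGatedDeltaNet.forward = AscendQwen3Next_GatedDeltaNet.forward`:
# instances of Upstream now dispatch to the replacement at call time.
Upstream.forward = PlatformImpl.forward
assert Upstream().forward(1) == 2
```
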