[Feat] support basic pcp&dcp for qwen3next (#6091)

### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (prefill context parallelism) and DCP (decode
context parallelism).

- vLLM version: v0.15.0
- vLLM main:
f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Bai Yongbin
2026-02-28 21:44:08 +08:00
committed by GitHub
parent 64fba51275
commit 9d09488b4a
16 changed files with 906 additions and 81 deletions

View File

@@ -12,15 +12,19 @@ import warnings
import torch
from einops import rearrange
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_delta_hupdate import chunk_gated_delta_rule_fwd_hupdate
from .chunk_o import chunk_fwd_o
from .chunk_o_update import chunk_fwd_o_update
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import input_guard
from .utils import input_guard, prepare_final_chunk_indices
from .wy_fast import recompute_w_u_fwd
@@ -35,7 +39,15 @@ def chunk_gated_delta_rule_fwd(
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
forward_context = get_forward_context()
num_decodes = 0
attn_metadata = forward_context.attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is not None:
num_decodes = attn_metadata.num_decodes
chunk_size = 64
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
@@ -56,6 +68,45 @@ def chunk_gated_delta_rule_fwd(
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
if get_pcp_group().world_size > 1:
h_update = chunk_gated_delta_rule_fwd_hupdate(
k=k,
w=w,
u=u,
g=g,
cu_seqlens=cu_seqlens,
num_decodes=num_decodes,
)
all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0)
final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
final_h_update = h_update[:, final_chunk_indices, :, :, :]
all_final_h_update = get_pcp_group().all_gather(final_h_update, 0)
updated_state = final_state.new_empty(get_pcp_group().world_size, *final_state.shape)
updated_state[0, ...] = all_final_state[0]
for i in range(1, get_pcp_group().world_size):
updated_final_state = all_final_state[i] + torch.matmul(
all_final_h_update[i, ...], updated_state[i - 1, ...]
)
updated_state[i, ...] = updated_final_state
final_state = updated_state[-1, ...]
if get_pcp_group().rank_in_group == 0:
updated_h_state = torch.zeros_like(final_state)
else:
updated_h_state = updated_state[get_pcp_group().rank_in_group - 1, ...]
h = chunk_fwd_o_update(
q=q,
v=v_new,
h=h,
h_update=h_update,
updated_h_state=updated_h_state,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
@@ -65,6 +116,7 @@ def chunk_gated_delta_rule_fwd(
scale=scale,
cu_seqlens=cu_seqlens,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -90,7 +142,6 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
q=q,
k=k,