[Feat] support basic pcp&dcp for qwen3next (#6091)

### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP (Decode
Context Parallelism).

- vLLM version: v0.15.0
- vLLM main:
f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Bai Yongbin
2026-02-28 21:44:08 +08:00
committed by GitHub
parent 64fba51275
commit 9d09488b4a
16 changed files with 906 additions and 81 deletions

View File

@@ -13,6 +13,8 @@ import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
@@ -96,6 +98,14 @@ def causal_conv1d_fn(
indices 0 and 3
out: (batch, dim, seqlen)
"""
forward_context = get_forward_context()
num_decodes = 0
attn_metadata = forward_context.attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()), None)
if attn_metadata is not None:
num_decodes = attn_metadata.num_decodes
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
@@ -108,6 +118,13 @@ def causal_conv1d_fn(
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)
width = weight.shape[1]
last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1])
if get_pcp_group().world_size > 1:
all_last_width_prefill_x = get_pcp_group().all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0)
pcp_rank = get_pcp_group().rank_in_group
if pcp_rank > 0:
conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...]
for i in range(len(seqlens)):
x_s = splits[i]
@@ -121,14 +138,25 @@ def causal_conv1d_fn(
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None,
initial_states=conv_states[cache_indices[i]][..., : (width - 1)],
)
)
if get_pcp_group().world_size > 1:
conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...]
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
def extract_last_width(x, start_loc, width):
    """Gather the trailing ``width`` columns of each packed sequence in ``x``.

    Args:
        x: packed activations of shape (dim, total_len); sequences are
            concatenated along the last axis (assumed layout — confirm
            with the caller in ``causal_conv1d_fn``).
        start_loc: cumulative offsets; ``start_loc[1:]`` are the exclusive
            end positions of each sequence within ``x``.
        width: number of trailing positions to extract per sequence.

    Returns:
        Tensor of shape (num_seqs, dim, width).
    """
    ends = start_loc[1:]
    # Column indices of the last `width` positions of every sequence,
    # shape (num_seqs, width) via broadcasting.
    # NOTE(review): assumes each sequence is at least `width` long;
    # a shorter one would produce negative (wrapping) indices.
    col_idx = (ends - width).unsqueeze(1) + torch.arange(width, device=x.device)
    gathered = x[:, col_idx]          # (dim, num_seqs, width)
    return gathered.permute(1, 0, 2)  # (num_seqs, dim, width)
@triton.jit
def _causal_conv1d_update_kernel_npu_tiled(
# Pointers