[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP
(Decode Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -13,6 +13,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from vllm.distributed import get_pcp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore
|
||||
|
||||
|
||||
@@ -96,6 +98,14 @@ def causal_conv1d_fn(
|
||||
indices 0 and 3
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
num_decodes = 0
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = next(iter(attn_metadata.values()), None)
|
||||
if attn_metadata is not None:
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
if x.stride(-1) != 1:
|
||||
@@ -108,6 +118,13 @@ def causal_conv1d_fn(
|
||||
seqlens = seqlens.tolist()
|
||||
splits = torch.split(x, seqlens, dim=-1)
|
||||
width = weight.shape[1]
|
||||
last_width_prefill_x = extract_last_width(x, query_start_loc[num_decodes:], conv_states.shape[-1])
|
||||
|
||||
if get_pcp_group().world_size > 1:
|
||||
all_last_width_prefill_x = get_pcp_group().all_gather(last_width_prefill_x.unsqueeze(0).contiguous(), 0)
|
||||
pcp_rank = get_pcp_group().rank_in_group
|
||||
if pcp_rank > 0:
|
||||
conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[pcp_rank - 1, ...]
|
||||
|
||||
for i in range(len(seqlens)):
|
||||
x_s = splits[i]
|
||||
@@ -121,14 +138,25 @@ def causal_conv1d_fn(
|
||||
activation=activation,
|
||||
return_final_states=True,
|
||||
final_states_out=conv_states[cache_indices[i]][..., : (width - 1)].unsqueeze(0),
|
||||
initial_states=conv_states[cache_indices[i]][..., : (width - 1)] if has_initial_state[i] else None,
|
||||
initial_states=conv_states[cache_indices[i]][..., : (width - 1)],
|
||||
)
|
||||
)
|
||||
|
||||
if get_pcp_group().world_size > 1:
|
||||
conv_states[cache_indices[num_decodes:]] = all_last_width_prefill_x[-1, ...]
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
return out_ref_tensor
|
||||
|
||||
|
||||
def extract_last_width(x, start_loc, width):
    """Gather, for every sequence, the ``width`` columns of ``x`` that end at its boundary.

    Args:
        x: Tensor indexed as ``x[:, columns]``; the gathered result is
            permuted as a 3-D tensor, so ``x`` is presumably 2-D
            ``(dim, total_tokens)`` -- confirm against the caller.
        start_loc: 1-D tensor of cumulative sequence offsets. The first entry
            is dropped, so entry ``i + 1`` is used as the exclusive end
            column of sequence ``i``.
        width: Number of trailing columns to take per sequence.

    Returns:
        The gathered columns with the first two axes swapped, i.e. one
        ``(x.shape[0], width)`` slice per sequence.

    NOTE(review): the window start is ``end - width`` with no bounds check, so
    for a sequence shorter than ``width`` the window reaches before that
    sequence's own start (and wraps around for negative indices) -- confirm
    that callers either guarantee long enough sequences or tolerate this.
    """
    seq_ends = start_loc[1:]
    # Column offsets inside one window: 0, 1, ..., width - 1.
    window = torch.arange(width, device=x.device)
    # Broadcast (num_seqs, 1) against (1, width) -> (num_seqs, width).
    first_col = (seq_ends - width).unsqueeze(1)
    gather_idx = first_col + window.unsqueeze(0)

    # Indexing the second axis with a 2-D index yields
    # (x.shape[0], num_seqs, width); move the sequence axis to the front.
    tails = x[:, gather_idx]
    return tails.permute(1, 0, 2)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _causal_conv1d_update_kernel_npu_tiled(
|
||||
# Pointers
|
||||
|
||||
Reference in New Issue
Block a user