[Feat] support basic pcp&dcp for qwen3next (#6091)

### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, covering both PCP (prefill context parallelism) and DCP (decode
context parallelism).
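
For context, a minimal usage sketch is shown below. It is a hedged illustration only: `decode_context_parallel_size` is the upstream vLLM engine argument for DCP, while the prefill-CP knob and the model id are assumed names; the actual vllm-ascend configuration may differ.

```python
# Hedged sketch of enabling CP for Qwen3-Next; the prefill-CP argument name and
# the model id are assumptions for illustration, not confirmed by this PR.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",  # assumed model id
    tensor_parallel_size=4,
    decode_context_parallel_size=2,  # DCP: shard the decode-phase KV cache across CP ranks
    # prefill_context_parallel_size=2,  # PCP (assumed name): split long prefill sequences across CP ranks
)
```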

- vLLM version: v0.15.0
- vLLM main:
f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Author: Bai Yongbin
Date: 2026-02-28 21:44:08 +08:00
Committed by: GitHub
Parent: 64fba51275
Commit: 9d09488b4a
16 changed files with 906 additions and 81 deletions


@@ -20,6 +20,7 @@ from typing import ClassVar
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from vllm.config import VllmConfig
from vllm.distributed import (
@@ -209,7 +210,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_use_hybrid_attn=common_long_seq_metadata.pcp_use_hybrid_attn,
pcp_unpad_mask=common_long_seq_metadata.pcp_unpad_mask,
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
pcp_fa_query_idx=common_long_seq_metadata.pcp_fa_query_idx,
pcp_padded_tokens_fla=common_long_seq_metadata.pcp_padded_tokens_fla,
pcp_enter_fa_restore_idx=common_long_seq_metadata.pcp_enter_fa_restore_idx,
)
prefill_metadata = AscendMetadataForPrefill(
@@ -469,6 +475,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
query = torch.index_select(query, 0, fa_query_idx)
q_head = torch.index_select(query, 0, q_head_idx)
q_tail = torch.index_select(query, 0, q_tail_idx)
k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) if self.pcp_rank > 0 else None
@@ -735,14 +745,18 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
def reshape_and_cache(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: torch.Tensor,
):
num_tokens = query.shape[0]
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
output_padded = output
if len(kv_cache) > 1:
if self.is_kv_producer:
@@ -762,14 +776,23 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
if has_prefill:
if self.pcp_size > 1:
kv = torch.cat([key, value], dim=-1)
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
key, value = all_kv.split([self.head_size, self.head_size], dim=-1)
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
if not attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
kv = torch.cat([key, value], dim=-1)
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
key, value = all_kv.split([self.head_size, self.head_size], dim=-1)
else:
query, key, value = self._gather_and_restore_pcp_qkv(query, key, value, attn_metadata)
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded
output_local_padded_tokens_fa = num_actual_tokens_pcp_padded // self.pcp_size - num_tokens
if output_local_padded_tokens_fa > 0:
output_padded = F.pad(
output, pad=(0, 0, 0, 0, 0, output_local_padded_tokens_fa), mode="constant", value=0
)
prefill_key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded]
prefill_value = value[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded]
slot_mapping = attn_metadata.slot_mapping[
@@ -784,7 +807,62 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
return query, key, value, output_padded
def _gather_and_restore_pcp_qkv(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
):
"""
Gather QKV chunks from all ranks in the PCP group and restore the original
sequence order for Context Parallelism (CP).
"""
num_tokens = query.shape[0]
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
pcp_padded_tokens_fla = attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla
num_tokens_pcp_padded_fla = num_tokens + pcp_padded_tokens_fla
qkv_fla = torch.cat(
[query.reshape(num_tokens, -1), key.reshape(num_tokens, -1), value.reshape(num_tokens, -1)],
dim=-1,
)
if pcp_padded_tokens_fla > 0:
qkv_fla = F.pad(qkv_fla, pad=(0, 0, 0, pcp_padded_tokens_fla), mode="constant", value=0)
all_qkv = get_pcp_group().all_gather(qkv_fla[:num_tokens_pcp_padded_fla].contiguous(), dim=0)
# Restore the original sequence order using pre-computed indices
pcp_enter_fa_restore_idx = (
attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx if attn_metadata.prefill.pcp_metadata else None
)
actual_qkv = torch.index_select(all_qkv, 0, pcp_enter_fa_restore_idx)
qkv_fa_padding_workspace = query.new_empty(
(num_actual_tokens_pcp_padded, (self.num_heads + 2 * self.num_kv_heads) * self.head_size)
)
decode_offset = attn_metadata.num_decode_tokens * self.pcp_size
qkv_fa_padding_workspace[:decode_offset] = actual_qkv[:decode_offset]
pcp_unpad_mask = attn_metadata.prefill.pcp_metadata.pcp_unpad_mask[attn_metadata.num_decodes * self.pcp_size :]
qkv_fa_padding_workspace[decode_offset:][pcp_unpad_mask] = actual_qkv[decode_offset:]
q, k, v = qkv_fa_padding_workspace.split(
[
self.num_heads * self.head_size,
self.num_kv_heads * self.head_size,
self.num_kv_heads * self.head_size,
],
dim=-1,
)
return (
q.reshape(-1, self.num_heads, self.head_size),
k.reshape(-1, self.num_kv_heads, self.head_size),
v.reshape(-1, self.num_kv_heads, self.head_size),
)
def _gather_global_context_output(self, local_context_attn_output):
if self.dcp_size > 1:
@@ -831,8 +909,15 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
pcp_use_hybrid_attn = False
if has_prefill:
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
pcp_use_hybrid_attn = attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn
if has_decode:
decode_query = query[:num_decode_tokens]
if pcp_use_hybrid_attn:
decode_query = query[: num_decode_tokens * self.pcp_size : self.pcp_size].contiguous()
else:
decode_query = query[:num_decode_tokens].contiguous()
output_decode = self._forward_decode_pcp_dcp(decode_query, attn_metadata)
output[:num_decode_tokens] = output_decode
if has_prefill:
@@ -849,7 +934,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
# qkv init
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
if pcp_use_hybrid_attn:
prefill_query = query[self.pcp_size * num_decode_tokens :]
else:
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
key = key[self.pcp_size * num_decode_tokens :].contiguous()
value = value[self.pcp_size * num_decode_tokens :].contiguous()
@@ -914,5 +1002,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
attn_output_prefill, attn_lse_prefill, context_output, context_lse, prefill_query, attn_metadata
)
if self.pcp_size > 1 and pcp_use_hybrid_attn:
# layer_idx != num_layers - 1
assert attn_metadata.prefill.pcp_metadata is not None
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx)
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
output[num_decode_tokens : attn_output_prefill.shape[0] + num_decode_tokens] = attn_output_prefill
return output
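
The prefill path above relies on a recurring pattern: concatenate K/V, all-gather the per-rank chunks over the PCP group, then use a precomputed restore index to put tokens back into original sequence order before splitting K and V apart again. Below is a minimal single-process sketch of that pattern; the collective is simulated with `torch.cat` and the restore index is a made-up permutation, so it only illustrates the indexing logic, not the real `get_pcp_group()` collective.

```python
import torch

pcp_size, chunk_len, head_size = 2, 4, 8

# Each PCP rank holds one sharded chunk of the padded K/V sequence.
local_kv = [torch.randn(chunk_len, 2 * head_size) for _ in range(pcp_size)]

# get_pcp_group().all_gather(..., dim=0) concatenates chunks in rank order;
# simulated here with torch.cat so the snippet runs on a single device.
all_kv = torch.cat(local_kv, dim=0)

# pcp_allgather_restore_idx is precomputed by the metadata builder so that an
# index_select restores the original token order (this permutation is illustrative).
restore_idx = torch.tensor([0, 1, 4, 5, 2, 3, 6, 7])
all_kv = torch.index_select(all_kv, 0, restore_idx)

# Split the fused buffer back into key / value, as reshape_and_cache does.
key, value = all_kv.split([head_size, head_size], dim=-1)
print(key.shape, value.shape)  # torch.Size([8, 8]) torch.Size([8, 8])
```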


@@ -25,7 +25,12 @@ class AscendPCPMetadata:
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_use_hybrid_attn: bool = False
pcp_unpad_mask: torch.Tensor = None
pcp_allgather_restore_idx: list[int] | None = None
pcp_fa_query_idx: torch.Tensor = None
pcp_padded_tokens_fla: int = 0
pcp_enter_fa_restore_idx: torch.Tensor = None
block_table_cp: torch.Tensor = None
valid_block_ids: torch.Tensor = None
prefill_q_cum_seqlens: torch.Tensor = None
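
To make the new fields concrete, here is a hedged sketch of the hybrid-attention fields of an `AscendPCPMetadata` instance. All values are illustrative placeholders, and the per-field comments are interpretations of the names and their usage in the diff, not the metadata builder's actual computation.

```python
import torch

# Illustrative values only; the real builder derives these from the scheduled batch.
hybrid_attn_fields = dict(
    pcp_use_hybrid_attn=True,                   # hybrid (FLA + full-attention) prefill path is active
    pcp_unpad_mask=torch.tensor([True, True, False, True]),
    pcp_allgather_restore_idx=torch.tensor([0, 2, 1, 3]),
    pcp_fa_query_idx=torch.tensor([0, 1, 3]),   # rows of the local query fed to the full-attention path
    pcp_padded_tokens_fla=1,                    # padding added so FLA chunks gather evenly across ranks
    pcp_enter_fa_restore_idx=torch.tensor([0, 2, 1, 3]),
)
# AscendPCPMetadata(**hybrid_attn_fields, ...)  # together with the existing seqlen/index fields
```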