[Feat] support basic pcp&dcp for qwen3next (#6091)

### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Parallel Context Parallelism) and DCP
(Dynamic/Data Context Parallelism).

- vLLM version: v0.15.0
- vLLM main:
f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Bai Yongbin
2026-02-28 21:44:08 +08:00
committed by GitHub
parent 64fba51275
commit 9d09488b4a
16 changed files with 906 additions and 81 deletions

View File

@@ -81,6 +81,9 @@ class TestAscendAttentionCPImpl(TestBase):
[0])
attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor(
[0])
attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx = torch.tensor(
[0, 1])
attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
output, attn_lse = self.impl._forward_prefill_cp(
query, key, value, attn_metadata)
@@ -257,12 +260,23 @@ class TestAscendAttentionCPImpl(TestBase):
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
[0, 3, 1, 2, 0, 0, 0, 0])
attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla = 0
attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx = torch.arange(
num_tokens * 3 * self.impl.pcp_size
)
attn_metadata.prefill.pcp_metadata.pcp_unpad_mask = torch.tensor(
[True, False, True, True, True, True, True, True]
)
query = torch.rand(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
output = torch.rand(num_tokens, num_heads * head_size)
key, value = self.impl.reshape_and_cache(key, value, kv_cache,
attn_metadata)
query, key, value, output = self.impl.reshape_and_cache(
query, key, value, kv_cache, attn_metadata, output
)
self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
self.assertEqual(key.shape[1], num_heads)
self.assertEqual(key.shape[2], head_size)