[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Parallel Context Parallelism) and DCP
(Dynamic/Data Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -81,6 +81,9 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
[0])
|
||||
attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor(
|
||||
[0])
|
||||
attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx = torch.tensor(
|
||||
[0, 1])
|
||||
attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
|
||||
|
||||
output, attn_lse = self.impl._forward_prefill_cp(
|
||||
query, key, value, attn_metadata)
|
||||
@@ -257,12 +260,23 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
attn_metadata.prefill = MagicMock()
|
||||
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
|
||||
[0, 3, 1, 2, 0, 0, 0, 0])
|
||||
attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
|
||||
attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla = 0
|
||||
attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx = torch.arange(
|
||||
num_tokens * 3 * self.impl.pcp_size
|
||||
)
|
||||
attn_metadata.prefill.pcp_metadata.pcp_unpad_mask = torch.tensor(
|
||||
[True, False, True, True, True, True, True, True]
|
||||
)
|
||||
|
||||
query = torch.rand(num_tokens, num_heads, head_size)
|
||||
key = torch.randn(num_tokens, num_heads, head_size)
|
||||
value = torch.randn(num_tokens, num_heads, head_size)
|
||||
output = torch.rand(num_tokens, num_heads * head_size)
|
||||
|
||||
key, value = self.impl.reshape_and_cache(key, value, kv_cache,
|
||||
attn_metadata)
|
||||
query, key, value, output = self.impl.reshape_and_cache(
|
||||
query, key, value, kv_cache, attn_metadata, output
|
||||
)
|
||||
self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
|
||||
self.assertEqual(key.shape[1], num_heads)
|
||||
self.assertEqual(key.shape[2], head_size)
|
||||
|
||||
Reference in New Issue
Block a user