[Feat] Support basic PCP & DCP for Qwen3-Next (#6091)

### What this PR does / why we need it?
This PR adds context parallelism (CP) support for the Qwen3-Next model,
covering both PCP (prefill context parallelism) and DCP (decode context
parallelism); a minimal usage sketch is shown below.
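
The sketch below mirrors the e2e test added in this PR and shows how the PCP/DCP knobs are set. `VllmRunner` is the helper used in vllm-ascend's e2e suite; its import path here is an assumption, not something this PR defines.

```python
# Minimal sketch (mirrors the e2e test in this PR; the VllmRunner import
# path is assumed).
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner  # assumed helper location

prompts = ["The president of the United States is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                enforce_eager=True,
                max_model_len=1024,
                tensor_parallel_size=2,
                prefill_context_parallel_size=2,  # PCP: shard the prefill sequence
                decode_context_parallel_size=1,   # DCP: shard KV cache at decode (1 = off)
                max_num_batched_tokens=1024,
                enable_expert_parallel=True,
                gpu_memory_utilization=0.8,
                block_size=128) as runner:
    runner.model.generate(prompts, sampling_params)
```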

- vLLM version: v0.15.0
- vLLM main: f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Commit 9d09488b4a (parent 64fba51275), authored by Bai Yongbin on 2026-02-28 21:44:08 +08:00, committed via GitHub.
16 changed files with 906 additions and 81 deletions.


@@ -44,16 +44,15 @@ def test_models_pcp_dcp_basic():
         runner.model.generate(prompts, sampling_params)
     model = "vllm-ascend/Qwen3-30B-A3B-W8A8"
-    with VllmRunner(
-            model,
-            enforce_eager=True,
-            max_model_len=1024,
-            tensor_parallel_size=2,
-            prefill_context_parallel_size=2,
-            decode_context_parallel_size=1,
-            enable_expert_parallel=True,
-            block_size=128,
-            quantization="ascend",
+    with VllmRunner(model,
+                    enforce_eager=True,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    prefill_context_parallel_size=2,
+                    decode_context_parallel_size=1,
+                    enable_expert_parallel=True,
+                    block_size=128,
+                    quantization="ascend",
     ) as runner:
         runner.model.generate(prompts, sampling_params)
@@ -71,6 +70,19 @@ def test_models_pcp_dcp_basic():
     ) as runner:
         runner.model.generate(prompts, sampling_params)
+    model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
+    with VllmRunner(model,
+                    enforce_eager=True,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    prefill_context_parallel_size=2,
+                    decode_context_parallel_size=1,
+                    max_num_batched_tokens=1024,
+                    enable_expert_parallel=True,
+                    gpu_memory_utilization=0.8,
+                    block_size=128) as runner:
+        runner.model.generate(prompts, sampling_params)

 def test_models_pcp_dcp_full_graph():
     prompts = [


@@ -81,6 +81,9 @@ class TestAscendAttentionCPImpl(TestBase):
             [0])
         attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor(
             [0])
+        attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx = torch.tensor(
+            [0, 1])
+        attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
         output, attn_lse = self.impl._forward_prefill_cp(
             query, key, value, attn_metadata)
@@ -257,12 +260,23 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata.prefill = MagicMock()
         attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
             [0, 3, 1, 2, 0, 0, 0, 0])
+        attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn = False
+        attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla = 0
+        attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx = torch.arange(
+            num_tokens * 3 * self.impl.pcp_size
+        )
+        attn_metadata.prefill.pcp_metadata.pcp_unpad_mask = torch.tensor(
+            [True, False, True, True, True, True, True, True]
+        )
         query = torch.rand(num_tokens, num_heads, head_size)
         key = torch.randn(num_tokens, num_heads, head_size)
         value = torch.randn(num_tokens, num_heads, head_size)
+        output = torch.rand(num_tokens, num_heads * head_size)
-        key, value = self.impl.reshape_and_cache(key, value, kv_cache,
-                                                 attn_metadata)
+        query, key, value, output = self.impl.reshape_and_cache(
+            query, key, value, kv_cache, attn_metadata, output
+        )
         self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
         self.assertEqual(key.shape[1], num_heads)
         self.assertEqual(key.shape[2], head_size)
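
Background on the `pcp_allgather_restore_idx` fixture above (not part of this PR, and the real implementation's layout may differ): with context parallelism each rank holds only a slice of the sequence, so after the all-gather the tokens arrive grouped by rank, and a restore index puts them back into the original order. A toy sketch of that idea:

```python
import torch

# Toy illustration with pcp_size = 2: four tokens split into hypothetical
# per-rank slices, then re-ordered after an all-gather.
tokens = torch.tensor([10, 11, 12, 13])         # original order
rank_shards = [tokens[[0, 3]], tokens[[1, 2]]]  # rank 0 holds tokens 0 and 3, rank 1 holds 1 and 2
gathered = torch.cat(rank_shards)               # all-gather order: 10, 13, 11, 12

# A restore index maps gathered positions back to the original token order.
restore_idx = torch.tensor([0, 2, 3, 1])
assert torch.equal(gathered[restore_idx], tokens)
```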