[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Parallel Context Parallelism) and DCP
(Dynamic/Data Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -377,6 +377,15 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_pcp = forward_context.max_tokens_across_pcp
|
||||
|
||||
self.num_tokens_pcp = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_pcp - self.num_tokens_pcp
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
|
||||
Reference in New Issue
Block a user