[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Parallel Context Parallelism) and DCP
(Dynamic/Data Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
h_update,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
@@ -72,6 +73,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
|
||||
b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32)
|
||||
b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32)
|
||||
# create b_hupd_bv1 and b_hupd_bv2
|
||||
|
||||
v_start1 = 0
|
||||
v_start2 = 64
|
||||
@@ -204,6 +206,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
h_update = k.new_empty(B, NT, H, K, K)
|
||||
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
@@ -223,6 +226,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
h_update=h_update,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
|
||||
Reference in New Issue
Block a user