[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP
(Decode Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -12,15 +12,19 @@ import warnings
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from vllm.distributed import get_pcp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_delta_hupdate import chunk_gated_delta_rule_fwd_hupdate
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_o_update import chunk_fwd_o_update
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import input_guard
|
||||
from .utils import input_guard, prepare_final_chunk_indices
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
@@ -35,7 +39,15 @@ def chunk_gated_delta_rule_fwd(
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
forward_context = get_forward_context()
|
||||
num_decodes = 0
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = next(iter(attn_metadata.values()), None)
|
||||
if attn_metadata is not None:
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
chunk_size = 64
|
||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
@@ -56,6 +68,45 @@ def chunk_gated_delta_rule_fwd(
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
if get_pcp_group().world_size > 1:
|
||||
h_update = chunk_gated_delta_rule_fwd_hupdate(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
num_decodes=num_decodes,
|
||||
)
|
||||
all_final_state = get_pcp_group().all_gather(final_state.unsqueeze(0), 0)
|
||||
final_chunk_indices = prepare_final_chunk_indices(cu_seqlens, chunk_size)
|
||||
final_h_update = h_update[:, final_chunk_indices, :, :, :]
|
||||
all_final_h_update = get_pcp_group().all_gather(final_h_update, 0)
|
||||
|
||||
updated_state = final_state.new_empty(get_pcp_group().world_size, *final_state.shape)
|
||||
updated_state[0, ...] = all_final_state[0]
|
||||
for i in range(1, get_pcp_group().world_size):
|
||||
updated_final_state = all_final_state[i] + torch.matmul(
|
||||
all_final_h_update[i, ...], updated_state[i - 1, ...]
|
||||
)
|
||||
updated_state[i, ...] = updated_final_state
|
||||
|
||||
final_state = updated_state[-1, ...]
|
||||
|
||||
if get_pcp_group().rank_in_group == 0:
|
||||
updated_h_state = torch.zeros_like(final_state)
|
||||
else:
|
||||
updated_h_state = updated_state[get_pcp_group().rank_in_group - 1, ...]
|
||||
|
||||
h = chunk_fwd_o_update(
|
||||
q=q,
|
||||
v=v_new,
|
||||
h=h,
|
||||
h_update=h_update,
|
||||
updated_h_state=updated_h_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
@@ -65,6 +116,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
@@ -90,7 +142,6 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
|
||||
Reference in New Issue
Block a user