[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Parallel Context Parallelism) and DCP
(Dynamic/Data Context Parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -892,10 +892,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def reshape_and_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
if len(kv_cache) > 1:
|
||||
if self.is_kv_producer:
|
||||
@@ -915,7 +917,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return key, value
|
||||
return query, key, value, output
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
@@ -970,12 +972,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.fill_(0)
|
||||
output_padded = None
|
||||
if key is not None and value is not None:
|
||||
key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)
|
||||
output_padded = output
|
||||
query, key, value, output_padded = self.reshape_and_cache(
|
||||
query, key, value, kv_cache, attn_metadata, output
|
||||
)
|
||||
# pooling model branch
|
||||
if attn_metadata.model_runner_type == "pooling":
|
||||
attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
if output_padded is not None:
|
||||
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded)
|
||||
else:
|
||||
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import ClassVar
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
@@ -209,7 +210,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_use_hybrid_attn=common_long_seq_metadata.pcp_use_hybrid_attn,
|
||||
pcp_unpad_mask=common_long_seq_metadata.pcp_unpad_mask,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
|
||||
pcp_fa_query_idx=common_long_seq_metadata.pcp_fa_query_idx,
|
||||
pcp_padded_tokens_fla=common_long_seq_metadata.pcp_padded_tokens_fla,
|
||||
pcp_enter_fa_restore_idx=common_long_seq_metadata.pcp_enter_fa_restore_idx,
|
||||
)
|
||||
|
||||
prefill_metadata = AscendMetadataForPrefill(
|
||||
@@ -469,6 +475,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
|
||||
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
|
||||
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
|
||||
if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
|
||||
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
|
||||
query = torch.index_select(query, 0, fa_query_idx)
|
||||
|
||||
q_head = torch.index_select(query, 0, q_head_idx)
|
||||
q_tail = torch.index_select(query, 0, q_tail_idx)
|
||||
k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) if self.pcp_rank > 0 else None
|
||||
@@ -735,14 +745,18 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
|
||||
def reshape_and_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
num_tokens = query.shape[0]
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
output_padded = output
|
||||
|
||||
if len(kv_cache) > 1:
|
||||
if self.is_kv_producer:
|
||||
@@ -762,14 +776,23 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
|
||||
if has_prefill:
|
||||
if self.pcp_size > 1:
|
||||
kv = torch.cat([key, value], dim=-1)
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
|
||||
assert attn_metadata.prefill is not None
|
||||
assert attn_metadata.prefill.pcp_metadata is not None
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
|
||||
all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
|
||||
key, value = all_kv.split([self.head_size, self.head_size], dim=-1)
|
||||
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
|
||||
if not attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
|
||||
kv = torch.cat([key, value], dim=-1)
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
|
||||
all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
|
||||
key, value = all_kv.split([self.head_size, self.head_size], dim=-1)
|
||||
else:
|
||||
query, key, value = self._gather_and_restore_pcp_qkv(query, key, value, attn_metadata)
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded
|
||||
output_local_padded_tokens_fa = num_actual_tokens_pcp_padded // self.pcp_size - num_tokens
|
||||
if output_local_padded_tokens_fa > 0:
|
||||
output_padded = F.pad(
|
||||
output, pad=(0, 0, 0, 0, 0, output_local_padded_tokens_fa), mode="constant", value=0
|
||||
)
|
||||
|
||||
prefill_key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded]
|
||||
prefill_value = value[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded]
|
||||
slot_mapping = attn_metadata.slot_mapping[
|
||||
@@ -784,7 +807,62 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return key, value
|
||||
return query, key, value, output_padded
|
||||
|
||||
def _gather_and_restore_pcp_qkv(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
):
|
||||
"""
|
||||
Gathers QKV chunks from all GPUs in the PCP group and restores the original
|
||||
sequence order for Context Parallelism (CP).
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded
|
||||
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
|
||||
pcp_padded_tokens_fla = attn_metadata.prefill.pcp_metadata.pcp_padded_tokens_fla
|
||||
num_tokens_pcp_padded_fla = num_tokens + pcp_padded_tokens_fla
|
||||
|
||||
qkv_fla = torch.cat(
|
||||
[query.reshape(num_tokens, -1), key.reshape(num_tokens, -1), value.reshape(num_tokens, -1)],
|
||||
dim=-1,
|
||||
)
|
||||
if pcp_padded_tokens_fla > 0:
|
||||
qkv_fla = F.pad(qkv_fla, pad=(0, 0, 0, pcp_padded_tokens_fla), mode="constant", value=0)
|
||||
all_qkv = get_pcp_group().all_gather(qkv_fla[:num_tokens_pcp_padded_fla].contiguous(), dim=0)
|
||||
|
||||
# Restore the original sequence order using pre-computed indices
|
||||
pcp_enter_fa_restore_idx = (
|
||||
attn_metadata.prefill.pcp_metadata.pcp_enter_fa_restore_idx if attn_metadata.prefill.pcp_metadata else None
|
||||
)
|
||||
actual_qkv = torch.index_select(all_qkv, 0, pcp_enter_fa_restore_idx)
|
||||
qkv_fa_padding_workspace = query.new_empty(
|
||||
(num_actual_tokens_pcp_padded, (self.num_heads + 2 * self.num_kv_heads) * self.head_size)
|
||||
)
|
||||
|
||||
decode_offset = attn_metadata.num_decode_tokens * self.pcp_size
|
||||
qkv_fa_padding_workspace[:decode_offset] = actual_qkv[:decode_offset]
|
||||
|
||||
pcp_unpad_mask = attn_metadata.prefill.pcp_metadata.pcp_unpad_mask[attn_metadata.num_decodes * self.pcp_size :]
|
||||
qkv_fa_padding_workspace[decode_offset:][pcp_unpad_mask] = actual_qkv[decode_offset:]
|
||||
|
||||
q, k, v = qkv_fa_padding_workspace.split(
|
||||
[
|
||||
self.num_heads * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
return (
|
||||
q.reshape(-1, self.num_heads, self.head_size),
|
||||
k.reshape(-1, self.num_kv_heads, self.head_size),
|
||||
v.reshape(-1, self.num_kv_heads, self.head_size),
|
||||
)
|
||||
|
||||
def _gather_global_context_output(self, local_context_attn_output):
|
||||
if self.dcp_size > 1:
|
||||
@@ -831,8 +909,15 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
pcp_use_hybrid_attn = False
|
||||
if has_prefill:
|
||||
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
|
||||
pcp_use_hybrid_attn = attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn
|
||||
if has_decode:
|
||||
decode_query = query[:num_decode_tokens]
|
||||
if pcp_use_hybrid_attn:
|
||||
decode_query = query[: num_decode_tokens * self.pcp_size : self.pcp_size].contiguous()
|
||||
else:
|
||||
decode_query = query[:num_decode_tokens].contiguous()
|
||||
output_decode = self._forward_decode_pcp_dcp(decode_query, attn_metadata)
|
||||
output[:num_decode_tokens] = output_decode
|
||||
if has_prefill:
|
||||
@@ -849,7 +934,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
|
||||
# qkv init
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
||||
if pcp_use_hybrid_attn:
|
||||
prefill_query = query[self.pcp_size * num_decode_tokens :]
|
||||
else:
|
||||
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
||||
key = key[self.pcp_size * num_decode_tokens :].contiguous()
|
||||
value = value[self.pcp_size * num_decode_tokens :].contiguous()
|
||||
|
||||
@@ -914,5 +1002,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
attn_output_prefill, attn_lse_prefill, context_output, context_lse, prefill_query, attn_metadata
|
||||
)
|
||||
|
||||
if self.pcp_size > 1 and pcp_use_hybrid_attn:
|
||||
# layer_idx != num_layers - 1
|
||||
assert attn_metadata.prefill.pcp_metadata is not None
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
|
||||
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
|
||||
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx)
|
||||
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
|
||||
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
|
||||
|
||||
output[num_decode_tokens : attn_output_prefill.shape[0] + num_decode_tokens] = attn_output_prefill
|
||||
return output
|
||||
|
||||
@@ -25,7 +25,12 @@ class AscendPCPMetadata:
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_use_hybrid_attn: bool = False
|
||||
pcp_unpad_mask: torch.Tensor = None
|
||||
pcp_allgather_restore_idx: list[int] | None = None
|
||||
pcp_fa_query_idx: torch.Tensor = None
|
||||
pcp_padded_tokens_fla: int = 0
|
||||
pcp_enter_fa_restore_idx: torch.Tensor = None
|
||||
block_table_cp: torch.Tensor = None
|
||||
valid_block_ids: torch.Tensor = None
|
||||
prefill_q_cum_seqlens: torch.Tensor = None
|
||||
|
||||
@@ -101,6 +101,21 @@ class AscendPrefillContextParallelMetadata:
|
||||
# original max_query_len before pcp split
|
||||
max_query_len_pcp_full: int = 0
|
||||
|
||||
# the following attributes are specifically used in hybrid-attn models.
|
||||
pcp_use_hybrid_attn: bool = False
|
||||
|
||||
pcp_unpad_mask: torch.Tensor = None
|
||||
|
||||
# to get the right order of query in prefill per rank
|
||||
pcp_fa_query_idx: torch.Tensor = None
|
||||
|
||||
# restore the full sequence across all pcp ranks
|
||||
# when entering from linear-attention to attention
|
||||
pcp_enter_fa_restore_idx: torch.Tensor = None
|
||||
|
||||
# the number of tokens padded in linear-attn per rank
|
||||
pcp_padded_tokens_fla: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
|
||||
Reference in New Issue
Block a user