[Feat]ds3.2 support pcp (#6733)
### What this PR does / why we need it?
This adapts the DeepSeek V3.2 model to support the PCP (prefill context
parallel) feature. The approach is: when saving the KV cache, first perform
an all-gather on the KV tensors so that each rank stores a full copy; then,
when the attention or the indexer runs, each rank all-gathers the KV cache
again and computes over the complete KV.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
02/12 23:05:10 - AISBench - INFO - Running 1-th replica of evaluation
02/12 23:05:10 - AISBench - INFO - Task [vllm-api-general-chat/gsm8k]:
{'accuracy': 96.35416666666667, 'type': 'GEN'}
02/12 23:05:10 - AISBench - INFO - time elapsed: 2.87s
02/12 23:05:12 - AISBench - INFO - Evaluation tasks completed.
02/12 23:05:12 - AISBench - INFO - Summarizing evaluation results...
dataset version metric mode vllm-api-general-chat
gsm8kdataset - accuracy gen 96.35
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -56,6 +56,20 @@ def test_models_pcp_dcp_basic():
|
||||
quantization="ascend",
|
||||
) as runner:
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning"
|
||||
with VllmRunner(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
prefill_context_parallel_size=2,
|
||||
decode_context_parallel_size=2,
|
||||
enable_expert_parallel=True,
|
||||
block_size=128,
|
||||
quantization="ascend",
|
||||
) as runner:
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
def test_models_pcp_dcp_full_graph():
|
||||
|
||||
@@ -26,6 +26,9 @@ class AscendPCPMetadata:
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_allgather_restore_idx: list[int] | None = None
|
||||
block_table_cp: torch.Tensor = None
|
||||
valid_block_ids: torch.Tensor = None
|
||||
prefill_q_cum_seqlens: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
440
vllm_ascend/attention/context_parallel/sfa_cp.py
Normal file
440
vllm_ascend/attention/context_parallel/sfa_cp.py
Normal file
@@ -0,0 +1,440 @@
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
|
||||
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
|
||||
from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
|
||||
|
||||
class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class

    Metadata builder for SFA (sparse flash attention) under context
    parallelism (PCP and/or DCP).  Extends the base SFA builder with the
    extra metadata needed to all-gather KV caches across CP ranks and to
    split prefill queries into head/tail halves.
    """

    def __init__(
        self,
        kv_cache_spec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[AscendSFAMetadata] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device, metadata_cls, supports_dcp_with_varlen)

        # In sfa, pcp prefill does not support mlapo
        self.enable_mlapo = enabling_mlapo(self.vllm_config)

        # PCP (prefill context parallel) group info; rank/group are only
        # meaningful when the group actually has more than one member.
        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None

        # DCP (decode context parallel) group info, same convention as PCP.
        self.dcp_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
        # Interleave granularity of the CP KV cache on this rank.
        self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
        # One "virtual" block spans all CP ranks' local interleave slices.
        self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
        # lcm(block_size, cp_virtual_block_size): smallest block size that is
        # a multiple of both the backend block size and the virtual CP block.
        self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
            self.block_size, self.cp_virtual_block_size
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendSFAMetadata:
        """Build SFA metadata, then attach the CP-specific fields.

        Returns the metadata object produced by the base builder, augmented
        with decode/prefill token counts and ``sfa_cp_metadata``.
        """
        # NOTE: despite the name, this is a metadata *instance* returned by
        # the parent builder, not a class.
        metadata_cls = super().build(common_prefix_len, common_attn_metadata, fast_build)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.decode_threshold
        )
        num_reqs = common_attn_metadata.num_reqs
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens

        block_table = metadata_cls.block_table
        # Deduplicate block ids; `new_block_table` re-indexes the table into
        # the compacted [0, num_blocks) id space of `valid_block_ids`.
        valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
        num_blocks = valid_block_ids.shape[0]
        # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
        # We assume that we can always get the correct kv_lens or kv index,
        # so we omit the dirty value processing here.
        # Each CP rank's gathered KV copy occupies a contiguous span of
        # `num_blocks` blocks, so rank r's block ids are offset by r*num_blocks.
        block_table_cp = (
            new_block_table.unsqueeze(-1).to(block_table)
            + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
        ).reshape(block_table.shape[0], -1)

        sfa_cp_metadata = self.build_cp_metadata(
            block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata
        )
        metadata_cls.num_decode_tokens = num_decode_tokens
        metadata_cls.num_decodes = num_decodes
        metadata_cls.num_prefills = num_prefills

        if self.pcp_size > 1:
            long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
            assert long_seq_metadata is not None
            num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
            slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
            if self.enable_mlapo:
                # With mlapo, decode tokens are strided across PCP ranks in
                # the slot mapping: compact this rank's decode slots to the
                # front and invalidate (-1) the now-unused tail slots.
                slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
                slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
            metadata_cls.slot_mapping = slot_mapping
            actual_seq_lengths_query = metadata_cls.cum_query_lens
            if num_prefills > 0 and num_decode_tokens > 0:
                # Rebase the cumulative prefill query lengths so they start
                # at 0 (strip the decode-token prefix).
                prefill_q_cum_seqlens = (
                    actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
                )
            else:
                prefill_q_cum_seqlens = actual_seq_lengths_query
            assert sfa_cp_metadata is not None
            sfa_cp_metadata.prefill_q_cum_seqlens = prefill_q_cum_seqlens
        metadata_cls.sfa_cp_metadata = sfa_cp_metadata
        return metadata_cls

    def build_cp_metadata(
        self,
        block_table_cp: torch.Tensor,
        valid_block_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ) -> AscendPCPMetadata | None:
        """Assemble the PCP metadata (head/tail split indices, no-mask KV
        lengths per half, and the gathered-KV block table)."""
        common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
        assert common_long_seq_metadata is not None
        # Causal no-mask KV lengths for the head/tail query halves of this
        # PCP rank (queries are split in half across ranks; the head half of
        # rank r attends to (r+1) half-chunks, the tail half to the rest).
        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
        return AscendPCPMetadata(
            q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
            q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
            q_full_idx=common_long_seq_metadata.q_full_idx,
            head_attn_nomask_seqlens=q_head_kv_lens,
            tail_attn_nomask_seqlens=q_tail_kv_lens,
            pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
            block_table_cp=block_table_cp,
            valid_block_ids=valid_block_ids,
        )
|
||||
class AscendSFACPImpl(AscendSFAImpl):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class

    SFA attention implementation under context parallelism.  KV caches are
    all-gathered across DCP/PCP ranks before attention and indexer-select,
    and prefill queries are processed as head/tail halves and re-merged.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        **kwargs,
    ):
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **kwargs,
        )
        # In sfa, pcp prefill does not support mlapo
        self.enable_mlapo = enabling_mlapo(self.vllm_config)
        # PCP/DCP group info; rank/group are placeholders when size == 1.
        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None

        self.dcp_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None

    def _execute_sparse_flash_attention_process(
        self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
    ):
        """Run sparse flash attention with CP-gathered KV.

        Gathers the nope/rope KV caches across DCP/PCP, then: single call
        when pcp_size == 1; otherwise decode tokens are computed directly
        and prefill tokens are computed as head/tail query halves and merged
        back via ``q_full_idx``.
        """
        kv = kv_cache[0]
        key_rope = kv_cache[1]

        assert attn_metadata.sfa_cp_metadata is not None
        valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
        kv = self.gather_kv_cross_cp(kv, valid_block_ids)
        key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
        # Block table re-indexed into the gathered (per-rank concatenated)
        # KV layout.
        block_table = attn_metadata.sfa_cp_metadata.block_table_cp

        if self.pcp_size == 1:
            return self._execute_sparse_flash_attention(
                ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
            )
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefills = attn_metadata.num_prefills
        decode_attn_out = None
        if num_decode_tokens > 0:
            decode_attn_out = self._execute_sparse_flash_attention(
                ql_nope[:num_decode_tokens],
                q_pe[:num_decode_tokens],
                kv,
                key_rope,
                block_table[:num_decode_tokens],
                topk_indices[:num_decode_tokens],
                actual_seq_lengths_query[:num_decode_tokens],
                actual_seq_lengths_key[:num_decode_tokens],
            )

        if num_prefills < 1:
            return decode_attn_out

        # q split for head and tail
        q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
        q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
        ql_nope = ql_nope[num_decode_tokens:]
        q_pe = q_pe[num_decode_tokens:]
        topk_indices = topk_indices[num_decode_tokens:]
        block_table = block_table[num_decode_tokens:]

        # q head compute
        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
        q_head_output = self._execute_sparse_flash_attention(
            torch.index_select(ql_nope, 0, q_head_idx),
            torch.index_select(q_pe, 0, q_head_idx),
            kv,
            key_rope,
            block_table,
            torch.index_select(topk_indices, 0, q_head_idx),
            # Each half holds half the prefill query tokens, hence // 2.
            attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2,
            q_head_actual_seq_lengths_key,
        )

        # q tail compute
        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
        q_tail_output = self._execute_sparse_flash_attention(
            torch.index_select(ql_nope, 0, q_tail_idx),
            torch.index_select(q_pe, 0, q_tail_idx),
            kv,
            key_rope,
            block_table,
            torch.index_select(topk_indices, 0, q_tail_idx),
            attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2,
            q_tail_actual_seq_lengths_key,
        )

        # Restore the original token order from the concatenated halves.
        q_full_idx = attn_metadata.sfa_cp_metadata.q_full_idx
        attn_output = torch.index_select(torch.cat([q_head_output, q_tail_output], dim=0), 0, q_full_idx)

        if decode_attn_out is not None:
            attn_output = torch.cat([decode_attn_out, attn_output], dim=0)
        return attn_output

    def _execute_sparse_flash_attention(
        self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
    ):
        """Thin wrapper over the NPU sparse flash attention kernel."""
        attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
            query=ql_nope,
            key=kv,
            value=kv,
            sparse_indices=topk_indices,
            scale_value=self.scale,
            sparse_block_size=1,
            block_table=block_table,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_kv=actual_seq_lengths_key,
            query_rope=q_pe,
            key_rope=key_rope,
            layout_query="TND",
            layout_kv="PA_BSND",
            sparse_mode=3,
        )
        return attn_output

    def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
        """Select the referenced blocks, then all-gather them across DCP and
        PCP so every rank holds the full KV (concatenated per rank along
        dim 0, matching ``block_table_cp``'s offsets)."""
        # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
        kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
        if self.dcp_size > 1:
            kv_cache = get_dcp_group().all_gather(kv_cache, 0)
        if self.pcp_size > 1:
            kv_cache = get_pcp_group().all_gather(kv_cache, 0)
        return kv_cache

    def indexer_select_post_process(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        q: torch.Tensor | None,
        k: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        cos: torch.Tensor,
        sin: torch.Tensor,
        actual_seq_lengths_query: torch.Tensor,
        actual_seq_lengths_key: torch.Tensor,
        need_gather_q_kv: bool = False,
    ):
        """Build the indexer query, write the indexer key cache, and select
        top-k KV indices, mirroring the head/tail CP split used by attention.
        """
        if q is None:
            q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
            q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
        cos_q, sin_q = cos, sin

        q_pe, q_nope = torch.split(
            q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
        )  # [b,s,64,64+64]

        # Apply RoPE only to the rope part of the query.
        q_pe = q_pe.unsqueeze(2)
        q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q)
        q_pe = q_pe.squeeze(2)
        q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]

        if kv_cache is not None:
            if self.is_kv_producer:
                # Event lets the KV connector wait for the cache write below.
                attn_metadata.reshape_cache_event = torch.npu.Event()
            torch_npu.npu_scatter_nd_update_(
                kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
            )  # b, s, n, d
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()

        weights, _ = self.weights_proj(x)
        weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)

        # Gather the indexer key cache across CP ranks before selection.
        key = kv_cache[2]
        assert attn_metadata.sfa_cp_metadata is not None
        key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
        block_table = attn_metadata.sfa_cp_metadata.block_table_cp

        if self.pcp_size == 1:
            return self._execute_indexer_select(
                q, key, weights, actual_seq_lengths_query, actual_seq_lengths_key, block_table
            )

        # decode compute
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefills = attn_metadata.num_prefills
        decode_topk_indices = None
        if num_decode_tokens > 0:
            decode_topk_indices = self._execute_indexer_select(
                q[:num_decode_tokens],
                key,
                weights[:num_decode_tokens],
                actual_seq_lengths_query[:num_decode_tokens],
                actual_seq_lengths_key[:num_decode_tokens],
                block_table[:num_decode_tokens],
            )

        # prefill compute
        if num_prefills == 0:
            return decode_topk_indices
        q = q[num_decode_tokens:]
        weights = weights[num_decode_tokens:]
        actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
        block_table = block_table[num_decode_tokens:]
        # pcp split for head and tail
        q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
        q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx

        # q head compute
        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
        q_head_topk_indices = self._execute_indexer_select(
            q=torch.index_select(q, 0, q_head_idx),
            key=key,
            weights=torch.index_select(weights, 0, q_head_idx),
            # Each half holds half the prefill query tokens, hence // 2.
            actual_seq_lengths_query=attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2,
            actual_seq_lengths_key=q_head_actual_seq_lengths_key,
            block_table=block_table,
        )

        # q tail compute
        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
        q_tail_topk_indices = self._execute_indexer_select(
            q=torch.index_select(q, 0, q_tail_idx),
            key=key,
            weights=torch.index_select(weights, 0, q_tail_idx),
            actual_seq_lengths_query=attn_metadata.sfa_cp_metadata.prefill_q_cum_seqlens // 2,
            actual_seq_lengths_key=q_tail_actual_seq_lengths_key,
            block_table=block_table,
        )

        # Restore original token order, then prepend decode results.
        q_full_idx = attn_metadata.sfa_cp_metadata.q_full_idx
        topk_indices = torch.index_select(torch.cat([q_head_topk_indices, q_tail_topk_indices], dim=0), 0, q_full_idx)
        if decode_topk_indices is not None:
            topk_indices = torch.cat([decode_topk_indices, topk_indices], dim=0)
        return topk_indices

    def _execute_indexer_select(self, q, key, weights, actual_seq_lengths_query, actual_seq_lengths_key, block_table):
        """Invoke the lightning indexer; two equivalent backends are kept
        because the torch_npu op currently has graph-compilation issues."""
        if self.use_torch_npu_lightning_indexer:
            topk_indices, _ = torch_npu.npu_lightning_indexer(
                query=q,
                key=key,
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        else:
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
                query=q,
                key=key,
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        return topk_indices

    def exec_kv(
        self,
        kv_no_split: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        kv_cache: tuple,
        slots: torch.Tensor,
        attn_metadata: M,
    ):
        """Compute normed kv_c and roped k_pe and write them to the cache.

        With PCP > 1 the per-rank KV is all-gathered and reordered via
        ``pcp_allgather_restore_idx`` so each rank caches the full KV; in
        that path the (k_pe, k_nope) return values are (None, None).
        """
        if self.pcp_size == 1:
            return super().exec_kv(kv_no_split, cos, sin, kv_cache, slots, attn_metadata)
        kv_c, k_pe = kv_no_split.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())  # type: ignore[misc]
        assert len(kv_cache) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
        assert attn_metadata.sfa_cp_metadata is not None
        kv_c_normed = kv_c_normed.view([kv_c_normed.shape[0], self.num_kv_heads, -1])
        k_pe = k_pe.unsqueeze(1)
        k_pe = self.rope_single(k_pe, cos, sin)
        # Concatenate so one all-gather moves both kv_c and k_pe.
        kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
        kv_c_k_pe = get_pcp_group().all_gather(kv_c_k_pe, 0)
        kv_c_k_pe = torch.index_select(kv_c_k_pe, 0, attn_metadata.sfa_cp_metadata.pcp_allgather_restore_idx)
        kv_c_normed, k_pe = kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        slot_mapping = attn_metadata.slot_mapping
        torch_npu._npu_reshape_and_cache(
            key=kv_c_normed, value=k_pe, key_cache=kv_cache[0], value_cache=kv_cache[1], slot_indices=slot_mapping
        )
        return None, None

    def _get_full_kv(self, k, attn_metadata: M):
        """All-gather the indexer key across PCP and restore token order;
        no-op when pcp_size == 1 or when mlapo is enabled."""
        if self.pcp_size == 1 or self.enable_mlapo:
            return k
        else:
            assert attn_metadata.sfa_cp_metadata is not None
            k = get_pcp_group().all_gather(k.contiguous(), 0)
            k = torch.index_select(k, 0, attn_metadata.sfa_cp_metadata.pcp_allgather_restore_idx)
            return k
|
||||
@@ -6,7 +6,7 @@ import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
|
||||
@@ -19,10 +19,12 @@ from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
|
||||
from vllm_ascend.attention.utils import (
|
||||
AscendCommonAttentionMetadata,
|
||||
ascend_chunked_prefill_workspace_size,
|
||||
enable_cp,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
trans_rope_weight,
|
||||
transdata,
|
||||
@@ -68,6 +70,10 @@ class AscendSFABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPMetadataBuilder
|
||||
|
||||
return AscendSFACPMetadataBuilder
|
||||
return AscendSFAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
@@ -76,6 +82,10 @@ class AscendSFABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AscendSFAImpl"]:
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPImpl
|
||||
|
||||
return AscendSFACPImpl
|
||||
return AscendSFAImpl
|
||||
|
||||
@staticmethod
|
||||
@@ -95,12 +105,6 @@ class DSACPContext:
|
||||
actual_seq_lengths_key: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFACPMetadata:
|
||||
block_table_cp: torch.Tensor
|
||||
valid_block_ids: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
@@ -133,7 +137,10 @@ class AscendSFAMetadata:
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
dsa_cp_context: DSACPContext | None = None
|
||||
reshape_cache_event: torch.npu.Event = None
|
||||
sfa_cp_metadata: SFACPMetadata | None = None
|
||||
sfa_cp_metadata: AscendPCPMetadata | None = None
|
||||
num_decodes: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
@@ -185,14 +192,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
|
||||
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
@staticmethod
|
||||
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
|
||||
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||
@@ -309,22 +308,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
)
|
||||
|
||||
sfa_cp_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
|
||||
num_blocks = valid_block_ids.shape[0]
|
||||
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
|
||||
# We assume that we can always get the correct kv_lens or kv index,
|
||||
# so we omit the dirty value processing here.
|
||||
block_table_cp = (
|
||||
new_block_table.unsqueeze(-1).to(block_table)
|
||||
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
|
||||
).reshape(block_table.shape[0], -1)
|
||||
sfa_cp_metadata = SFACPMetadata(
|
||||
block_table_cp=block_table_cp,
|
||||
valid_block_ids=valid_block_ids,
|
||||
)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -338,7 +321,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
sin=sin[:num_input_tokens],
|
||||
cos=cos[:num_input_tokens],
|
||||
dsa_cp_context=dsa_cp_context,
|
||||
sfa_cp_metadata=sfa_cp_metadata,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -453,14 +435,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
)
|
||||
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# NOTE: We currently do not support quant kv_b_proj.
|
||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||
@@ -562,6 +536,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def _get_full_kv(self, k, attn_metadata):
|
||||
return k
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
kv_no_split: torch.Tensor,
|
||||
@@ -569,6 +546,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
sin: torch.Tensor,
|
||||
kv_cache: tuple,
|
||||
slots: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
):
|
||||
B = kv_no_split.shape[0]
|
||||
N = self.num_kv_heads
|
||||
@@ -835,7 +813,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
|
||||
actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
|
||||
|
||||
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping)
|
||||
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata)
|
||||
|
||||
if self.enable_dsa_cp:
|
||||
assert k_pe is not None
|
||||
@@ -875,6 +853,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope)
|
||||
torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe)
|
||||
|
||||
k = self._get_full_kv(k, attn_metadata)
|
||||
if kv_cache is not None:
|
||||
torch_npu.npu_scatter_nd_update_(
|
||||
kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1])
|
||||
@@ -894,31 +873,8 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
)
|
||||
|
||||
block_table = attn_metadata.block_table
|
||||
kv = kv_cache[0]
|
||||
key_rope = kv_cache[1]
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
|
||||
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
|
||||
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
key=kv,
|
||||
value=kv,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=block_table,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_key,
|
||||
query_rope=q_pe,
|
||||
key_rope=key_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
attn_output = self._execute_sparse_flash_attention_process(
|
||||
ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
|
||||
)
|
||||
|
||||
attn_output = self._v_up_proj(attn_output)
|
||||
@@ -950,14 +906,30 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
return output_padded
|
||||
|
||||
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
|
||||
# Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
|
||||
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
|
||||
if self.dcp_size > 1:
|
||||
kv_cache = get_dcp_group().all_gather(kv_cache, 0)
|
||||
if self.pcp_size > 1:
|
||||
kv_cache = get_pcp_group().all_gather(kv_cache, 0)
|
||||
return kv_cache
|
||||
def _execute_sparse_flash_attention_process(
|
||||
self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
|
||||
):
|
||||
block_table = attn_metadata.block_table
|
||||
kv = kv_cache[0]
|
||||
key_rope = kv_cache[1]
|
||||
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
key=kv,
|
||||
value=kv,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=block_table,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_key,
|
||||
query_rope=q_pe,
|
||||
key_rope=key_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def indexer_select_pre_process(
|
||||
self,
|
||||
@@ -1038,10 +1010,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
key = kv_cache[2]
|
||||
block_table = attn_metadata.block_table
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
# DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer.
|
||||
# So two branches are maintained temporarily.
|
||||
|
||||
Reference in New Issue
Block a user