### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
The remainder of the page is the commit diff for one of the touched files (apparently `vllm_ascend/attention/context_parallel/common_cp.py`, judging by the classes it defines). The recurring pattern is a typing and formatting cleanup: `Optional[X]` annotations become `X | None`, and multi-line imports, calls, and signatures are reflowed onto single lines.

```diff
@@ -1,12 +1,9 @@
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import (get_dcp_group,
-                              get_decode_context_model_parallel_world_size,
-                              get_pcp_group)
+from vllm.distributed import get_dcp_group, get_decode_context_model_parallel_world_size, get_pcp_group
 
 
 @dataclass
```
```diff
@@ -17,6 +14,7 @@ class AscendPCPMetadata:
     Stores index tensors and sequence lengths for routing attention
     computations across PCP ranks during long sequence processing.
     """
+
     q_head_idx: torch.Tensor = None
     q_tail_idx: torch.Tensor = None
     kv_with_q_head_nomask_idx: torch.Tensor = None
```
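The `q_head_idx` / `q_tail_idx` / `kv_with_q_head_nomask_idx` fields above follow the head/tail partitioning commonly used for causal context parallelism: each PCP rank takes one chunk from the front of the sequence and one from the back so the causal-mask workload stays balanced across ranks. A toy sketch of that general scheme (illustrative only; `head_tail_indices` is a hypothetical helper, not the repo's actual indexing code):

```python
import torch

def head_tail_indices(seq_len: int, pcp_size: int, rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Split the sequence into 2 * pcp_size equal chunks; rank i owns chunk i
    # (head) and chunk 2 * pcp_size - 1 - i (tail).
    chunk = seq_len // (2 * pcp_size)
    head = torch.arange(rank * chunk, (rank + 1) * chunk)
    tail_start = (2 * pcp_size - 1 - rank) * chunk
    tail = torch.arange(tail_start, tail_start + chunk)
    return head, tail

# seq_len=16, pcp_size=2: rank 0 gets tokens [0..3] and [12..15],
# rank 1 gets tokens [4..7] and [8..11].
```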
```diff
@@ -27,7 +25,7 @@ class AscendPCPMetadata:
     head_attn_nomask_seqlens: torch.Tensor = None
     tail_attn_nomask_seqlens: torch.Tensor = None
     q_full_idx: torch.Tensor = None
-    pcp_allgather_restore_idx: Optional[list[int]] = None
+    pcp_allgather_restore_idx: list[int] | None = None
 
 
 @dataclass
```
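These dataclass annotations are evaluated when the class body runs, so the `X | None` spelling relies on Python 3.10+ (or `from __future__ import annotations`), which is why the `typing.Optional` import can be dropped entirely in the first hunk. A minimal sketch with a hypothetical field name showing the two spellings behave identically:

```python
from dataclasses import dataclass


@dataclass
class ExampleMetadata:
    # Previously spelled: restore_idx: Optional[list[int]] = None
    restore_idx: list[int] | None = None


assert ExampleMetadata().restore_idx is None
assert ExampleMetadata(restore_idx=[2, 0, 1]).restore_idx == [2, 0, 1]
```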
```diff
@@ -37,6 +35,7 @@ class CPChunkedContextMetadata:
 
     Extends chunked prefill with per-rank chunk information for PCP/DCP.
     """
+
     # For handling chunked prefill
     cu_seq_lens: torch.Tensor
     starts: torch.Tensor
```
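`cu_seq_lens`, `starts`, and the `chunk_seq_lens*` fields describe how a long context is carved into chunks for chunked prefill; the `cu_` prefix conventionally means cumulative, so chunk `i` of a packed buffer spans `[cu_seq_lens[i], cu_seq_lens[i + 1])`. A small illustrative sketch with made-up lengths (not taken from the repo):

```python
import torch

# Three requests whose context chunks contribute 5, 3 and 7 tokens this step.
chunk_seq_lens = torch.tensor([5, 3, 7])
# Cumulative form used to slice each chunk out of the packed buffer.
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.long), chunk_seq_lens.cumsum(0)])
print(cu_seq_lens)  # tensor([ 0,  5,  8, 15])
```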
```diff
@@ -47,48 +46,51 @@ class CPChunkedContextMetadata:
     chunk_seq_lens_npu: torch.Tensor
     # for mla DCP & PCP
     padded_chunk_seq_lens_npu: torch.Tensor = None
-    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
-    local_context_lens_allranks: Optional[list[list[int]]] = None
+    padded_local_chunk_seq_lens: list[list[int]] | None = None
+    local_context_lens_allranks: list[list[int]] | None = None
     padded_local_cu_seq_lens: torch.Tensor = None
-    cu_seq_lens_lst: Optional[list[list[int]]] = None
-    chunk_size: Optional[int] = None
+    cu_seq_lens_lst: list[list[int]] | None = None
+    chunk_size: int | None = None
 
 
 @dataclass
 class AscendMetadataForPrefill:
-    """ Prefill-specific metadata for Ascend attention with Context Parallelism."""
+    """Prefill-specific metadata for Ascend attention with Context Parallelism."""
 
     @dataclass
     class ChunkedContextMetadata:
         """Metadata for chunked context processing within prefill phase."""
+
         actual_chunk_seq_lengths: torch.Tensor
         actual_seq_lengths_kv: torch.Tensor
         starts: torch.Tensor
         chunk_seq_mask_filtered_indices: torch.Tensor
-        chunked_req_mask: Optional[list[bool]] = None
-        local_context_lens_allranks: Optional[list[list[int]]] = None
-        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
-        kv_inverse_idx_for_chunk: Optional[list[int]] = None
-        batch_chunk_seq_mask: Optional[list[bool]] = None
-        local_total_toks: Optional[int] = None
+        chunked_req_mask: list[bool] | None = None
+        local_context_lens_allranks: list[list[int]] | None = None
+        cp_kv_recover_idx_for_chunk: list[int] | None = None
+        kv_inverse_idx_for_chunk: list[int] | None = None
+        batch_chunk_seq_mask: list[bool] | None = None
+        local_total_toks: int | None = None
 
     """ Prefill Specific Metadata for Ascend"""
-    pcp_metadata: Optional[AscendPCPMetadata] = None
-    chunked_context: Optional[ChunkedContextMetadata] = None
+    pcp_metadata: AscendPCPMetadata | None = None
+    chunked_context: ChunkedContextMetadata | None = None
     block_tables: torch.Tensor = None
    actual_seq_lengths_q: torch.Tensor = None
 
 
 @dataclass
 class AscendMetadataForDecode:
-    """ Decode-specific metadata for Ascend attention with Context Parallelism."""
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    """Decode-specific metadata for Ascend attention with Context Parallelism."""
+
+    num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
     batch_seq_mask: torch.Tensor = None
     block_tables: torch.Tensor = None
 
 
-def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
-                          batch_seq_mask: torch.Tensor) -> torch.Tensor:
+def _process_attn_out_lse(
+    attn_output: torch.Tensor, softmax_lse: torch.Tensor, batch_seq_mask: torch.Tensor
+) -> torch.Tensor:
     pcp_size = get_pcp_group().world_size
     dcp_size = get_decode_context_model_parallel_world_size()
     dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
```
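Both metadata classes default their context-parallel fields to `None`, so downstream attention code has to branch on what is actually present rather than assume PCP/DCP or chunked-prefill state always exists. A consumer-side sketch of that pattern (illustrative only, not code from the repo):

```python
def describe_prefill(prefill_meta: "AscendMetadataForPrefill") -> str:
    # Hypothetical helper: report which optional metadata a prefill step carries.
    parts = []
    if prefill_meta.pcp_metadata is not None:
        parts.append("PCP routing indices")
    if prefill_meta.chunked_context is not None:
        parts.append("chunked-context metadata")
    return ", ".join(parts) if parts else "plain prefill"
```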
```diff
@@ -104,21 +106,17 @@ def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
     # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
     attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
     attn_out_lse_all2all = torch.empty_like(attn_out_lse)
-    dist.all_to_all_single(attn_out_lse_all2all,
-                           attn_out_lse,
-                           group=dcp_group)
+    dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
     attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
 
     if pcp_size > 1:
         # AllGather out&lse within CP group
-        attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(),
-                                                  dim=0)
+        attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)
 
     return attn_out_lse
 
 
-def _npu_attention_update(head_size,
-                          attn_out_lse: torch.Tensor) -> torch.Tensor:
+def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor:
     pcp_size = get_pcp_group().world_size
     dcp_size = get_decode_context_model_parallel_world_size()
     # [PCP * S, DCP * H, D+1]
```
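`_process_attn_out_lse` moves the attention output and its log-sum-exp through the collectives as a single packed tensor, which is why the hunk permutes `[bs, num_heads, v_head_dim+1]` to put the batch dimension last before `all_to_all_single` (that collective splits its input along dim 0). A shape-only sketch of the packing, runnable without any process group (random placeholder values; names are assumed from the hunk's comments):

```python
import torch

bs, num_heads, v_head_dim = 4, 8, 128
attn_output = torch.randn(bs, num_heads, v_head_dim)
softmax_lse = torch.randn(bs, num_heads, 1)

# Pack output and lse along the last dim: [bs, num_heads, v_head_dim + 1].
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# Put batch last so all_to_all_single would split across heads (dim 0 after permute).
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
assert attn_out_lse.shape == (num_heads, v_head_dim + 1, bs)
```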
```diff
@@ -134,8 +132,7 @@ def _npu_attention_update(head_size,
     # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
     x = x.view(-1, S, H, D_plus_1)
     # Split out lse
-    out_flat, lse_flat = torch.split(x, [D, 1],
-                                     dim=-1)  # [N, S, H, D], [N, S, H, 1]
+    out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)  # [N, S, H, D], [N, S, H, 1]
     # out: [N, S, H, D] -> [N, S*H, D]
     # lse: [N, S, H, 1] -> [N, S*H]
     out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
```
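`_npu_attention_update` receives the stacked partial results as `[N, S, H, D+1]` with `N = pcp_size * dcp_size`, splits them back into `out` and `lse`, and combines them into the final attention output. The standard log-sum-exp merge that such an update step computes, written as a plain-PyTorch reference (this is only the math; the function name suggests the module dispatches the actual computation to a torch_npu kernel rather than doing it in Python):

```python
import torch

def merge_partial_attention(out_flat: torch.Tensor, lse_flat: torch.Tensor) -> torch.Tensor:
    """Combine N partial attention outputs using their log-sum-exp values.

    out_flat: [N, T, D] partial outputs over disjoint KV shards (T = S * H rows);
    lse_flat: [N, T] per-row log-sum-exp of the corresponding softmax denominators.
    """
    merged_lse = torch.logsumexp(lse_flat, dim=0)            # [T]
    weights = torch.exp(lse_flat - merged_lse.unsqueeze(0))  # [N, T], sums to 1 over N
    return (weights.unsqueeze(-1) * out_flat).sum(dim=0)     # [T, D]
```

With `N = 1` the weights are exactly one and the merge reduces to the identity, which is a convenient sanity check.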