from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import (get_dcp_group,
                              get_decode_context_model_parallel_world_size,
                              get_pcp_group)


@dataclass
class AscendPCPMetadata:
"""
Metadata for Prefill Context Parallelism (PCP) on Ascend devices.
Stores index tensors and sequence lengths for routing attention
computations across PCP ranks during long sequence processing.
"""
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_use_hybrid_attn: bool = False
pcp_unpad_mask: torch.Tensor = None
pcp_allgather_restore_idx: list[int] | None = None
pcp_fa_query_idx: torch.Tensor = None
pcp_padded_tokens_fla: int = 0
pcp_enter_fa_restore_idx: torch.Tensor = None
block_table_cp: torch.Tensor = None
valid_block_ids: torch.Tensor = None
prefill_q_cum_seqlens: torch.Tensor = None
    block_arange: torch.Tensor | None = None
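

# Illustration only (not part of the original module): one plausible way the
# head/tail query split behind `q_head_idx` / `q_tail_idx` could be derived.
# Under a causal mask, splitting a sequence into 2 * pcp_size chunks and giving
# rank r chunk r (the "head") plus chunk 2 * pcp_size - 1 - r (the "tail")
# balances attention FLOPs across PCP ranks. The helper name is hypothetical;
# the real index builder lives in the metadata-building code.
def _sketch_head_tail_q_idx(seq_len: int, pcp_size: int,
                            rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    chunk = seq_len // (2 * pcp_size)
    head_start = rank * chunk
    tail_start = (2 * pcp_size - 1 - rank) * chunk
    # Token indices of this rank's head and tail query chunks.
    q_head_idx = torch.arange(head_start, head_start + chunk)
    q_tail_idx = torch.arange(tail_start, tail_start + chunk)
    return q_head_idx, q_tail_idx
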
@dataclass
class CPChunkedContextMetadata:
"""
Metadata for chunked context handling in Context Parallelism (CP).
Extends chunked prefill with per-rank chunk information for PCP/DCP.
"""
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
    chunk_actual_seq_lengths_kv_list: list[list[int]]
    # For MLA DCP & PCP
    padded_chunk_seq_lens_npu: torch.Tensor | None = None
    padded_local_chunk_seq_lens: list[list[int]] | None = None
    local_context_lens_allranks: list[list[int]] | None = None
    padded_local_cu_seq_lens: torch.Tensor | None = None
    cu_seq_lens_lst: list[list[int]] | None = None
    chunk_size: int | None = None
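

# Sketch (an assumption, not the original builder): `cu_seq_lens` and
# `chunk_actual_seq_lengths_kv_list` are natural candidates to precompute once
# at metadata-build time so the forward pass avoids per-step cumsums. The
# helper and its [num_chunks, num_reqs] `chunk_seq_lens` argument layout are
# hypothetical.
def _sketch_build_chunk_cu_seq_lens(
        chunk_seq_lens: torch.Tensor) -> tuple[torch.Tensor, list[list[int]]]:
    num_chunks = chunk_seq_lens.shape[0]
    zeros = chunk_seq_lens.new_zeros(num_chunks, 1)
    # Prepend 0 so cu_seq_lens[c, i] is the start offset of request i in chunk c.
    cu_seq_lens = torch.cat([zeros, chunk_seq_lens.cumsum(dim=1)], dim=1)
    # Downstream attention ops are assumed to consume the cumulative per-chunk
    # KV lengths as plain Python lists.
    chunk_actual_seq_lengths_kv_list = cu_seq_lens[:, 1:].tolist()
    return cu_seq_lens, chunk_actual_seq_lengths_kv_list
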
@dataclass
class AscendMetadataForPrefill:
"""Prefill-specific metadata for Ascend attention with Context Parallelism."""
@dataclass
class ChunkedContextMetadata:
"""Metadata for chunked context processing within prefill phase."""
actual_chunk_seq_lengths: torch.Tensor
actual_seq_lengths_kv: torch.Tensor
starts: torch.Tensor
chunk_seq_mask_filtered_indices: torch.Tensor
chunked_req_mask: list[bool] | None = None
local_context_lens_allranks: list[list[int]] | None = None
cp_kv_recover_idx_for_chunk: list[int] | None = None
kv_inverse_idx_for_chunk: list[int] | None = None
local_total_toks: int | None = None
""" Prefill Specific Metadata for Ascend"""

    pcp_metadata: AscendPCPMetadata | None = None
    pcp_exit_fa_scatter_idx: torch.Tensor | None = None
    chunked_context: ChunkedContextMetadata | None = None
    block_tables: torch.Tensor | None = None
    actual_seq_lengths_q: torch.Tensor | None = None

@dataclass
class AscendMetadataForDecode:
"""Decode-specific metadata for Ascend attention with Context Parallelism."""
num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
block_tables: torch.Tensor = None
def _process_attn_out_lse(attn_output: torch.Tensor,
                          softmax_lse: torch.Tensor) -> torch.Tensor:
    """Pack each rank's partial attention output with its softmax LSE, then
    exchange the packed tensor across the DCP group (all-to-all over the head
    dim) and the PCP group (all-gather over the token dim) so every rank holds
    all partial results it needs to merge."""
    pcp_size = get_pcp_group().world_size
    dcp_size = get_decode_context_model_parallel_world_size()
    dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
    softmax_lse = softmax_lse.to(torch.float32)
    attn_output = attn_output.to(torch.float32)
    # Concat out & lse: [bs, num_heads, v_head_dim] + [bs, num_heads, 1]
    # -> [bs, num_heads, v_head_dim + 1]
    attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
    if dcp_size > 1:
        # Permute: [bs, num_heads, v_head_dim + 1] -> [num_heads, v_head_dim + 1, bs]
        attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
        attn_out_lse_all2all = torch.empty_like(attn_out_lse)
        dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
        attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
    if pcp_size > 1:
        # All-gather out & lse within the PCP group.
        attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)
    return attn_out_lse
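

# Illustrative shape walkthrough (assumes bs=4, num_heads=16, v_head_dim=128,
# pcp_size=2, dcp_size=2 on the calling rank):
#   cat        -> [4, 16, 129]
#   all-to-all -> still [4, 16, 129], but the head dim now stacks each DCP
#                 rank's partial out+lse for this rank's 8-head slice
#   all-gather -> [8, 16, 129], token dim gathered across the PCP group
# `_npu_attention_update` below merges the pcp_size * dcp_size partials per
# (token, head) using the attached LSE values.
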
def _npu_attention_update(head_size: int,
                          attn_out_lse: torch.Tensor) -> torch.Tensor:
    """Merge the pcp_size * dcp_size partial attention outputs packed in
    `attn_out_lse` ([PCP * S, DCP * H, head_size + 1]) into a single
    [S, H, head_size] output using their LSE values."""
    pcp_size = get_pcp_group().world_size
    dcp_size = get_decode_context_model_parallel_world_size()
    # [PCP * S, DCP * H, D + 1]
    B_total, H_total, D_plus_1 = attn_out_lse.shape
    S = B_total // pcp_size
    H = H_total // dcp_size
    D = head_size
    assert D_plus_1 == D + 1
    # [PCP, S, DCP, H, D + 1]
    x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1)
    # [PCP, DCP, S, H, D + 1]
    x = x.permute(0, 2, 1, 3, 4).contiguous()
    # Flatten to [N, S, H, D + 1], N = pcp_size * dcp_size
    x = x.view(-1, S, H, D_plus_1)
    # Split the output from the LSE.
    out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)  # [N, S, H, D], [N, S, H, 1]
    # out: [N, S, H, D] -> [N, S*H, D]
    # lse: [N, S, H, 1] -> [N, S*H]
    out_flat = out_flat.flatten(1, 2)
    lse_flat = lse_flat.flatten(1, -1)
    # Unbind into one tensor per partial result.
    out_list = out_flat.unbind(0)  # each [S*H, D]
    lse_list = lse_flat.unbind(0)  # each [S*H]
    attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
    attn_out = attn_out.view(-1, H, D)
    return attn_out
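

# Reference for the merge semantics assumed of `torch_npu.npu_attention_update`
# above: partial outputs over disjoint KV shards combine weighted by their
# normalized exp(LSE) values. A minimal pure-PyTorch sketch for illustration,
# not a drop-in replacement for the NPU op:
def _reference_lse_merge(lse_list: list[torch.Tensor],
                         out_list: list[torch.Tensor]) -> torch.Tensor:
    lse = torch.stack(list(lse_list), dim=0)       # [N, S*H]
    out = torch.stack(list(out_list), dim=0)       # [N, S*H, D]
    lse_max = lse.max(dim=0, keepdim=True).values  # subtract max for stability
    weights = torch.exp(lse - lse_max)             # [N, S*H]
    weights = weights / weights.sum(dim=0, keepdim=True)
    return (weights.unsqueeze(-1) * out).sum(dim=0)  # [S*H, D]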