[Refactor] Replace npu_ring_mla with FIA in MLA prefill (#5704)
### What this PR does / why we need it?

**Refactor: Replace npu_ring_mla with FIA in MLA prefill**

This PR refactors the MLA (Multi-head Latent Attention) prefill implementation by replacing `npu_ring_mla` with the `npu_fused_infer_attention_score` (FIA) operator, unifying the attention backend with the standard attention implementation.

**Key changes:**

1. **Core prefill refactoring (`mla_v1.py`)**
   - Replace `npu_ring_mla` with `npu_fused_infer_attention_score` in `_forward_prefill` and `_compute_prefill_context`
   - Use the TND layout with `softmax_lse_flag=True` for prefill attention
   - Use `npu_attention_update` to merge multiple chunk outputs with their LSE (log-sum-exp) values (a plain-PyTorch sketch of this merge follows the description)
   - Change `attn_mask` from `get_final_mla_mask()` to `get_splitfuse_attn_mask()` for FIA compatibility
2. **Data type handling**
   - Add automatic float16 → bfloat16 conversion (FIA with the TND layout only supports bfloat16)
   - Convert the output back to the original dtype after the FIA computation
3. **Metadata optimization**
   - Pre-calculate `actual_seq_lengths_q` in `AscendMLAPrefillMetadata`
   - Pre-calculate `chunk_actual_seq_lengths_kv_list` in `ChunkedContextMetadata`
   - Move `torch.cumsum` operations from the forward pass to the metadata-building phase
4. **CP compatibility (`mla_cp.py`)**
   - Add `_ring_mla_mask_builder` to provide `npu_ring_mla`-compatible masks for Context Parallel scenarios
   - Add the `chunk_actual_seq_lengths_kv_list` field to `CPChunkedContextMetadata`

**Why we need it:**

- **Backend unification**: Aligns MLA prefill with the standard attention implementation (`attention_v1.py`)
- **Better chunked-context support**: FIA + `npu_attention_update` provides native LSE-based output merging
- **Future compatibility**: Prepares for the eventual removal of `npu_ring_mla` across the codebase

### Does this PR introduce _any_ user-facing change?

**No.** This is a pure refactoring with no functional changes: same behavior, unified backend.

---

- Related issue: #5463 (item 7)
- vLLM version: v0.14.1

Signed-off-by: lico67373 <918688502@qq.com>
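As background for the chunked-context change: FIA returns both the attention output and its softmax log-sum-exp (LSE) per query and head, and `npu_attention_update` is used to fold the prefix result and the per-chunk results into the exact full-context output. The snippet below is a minimal plain-PyTorch sketch of that LSE-based merge; the function name and the `[num_tokens, H, D]` / `[num_tokens, H]` shapes are illustrative assumptions, not the NPU operator's actual signature or layout.

```python
import torch


def merge_chunk_outputs(outputs: list[torch.Tensor],
                        lses: list[torch.Tensor]) -> torch.Tensor:
    """Combine attention results computed over disjoint KV chunks.

    Each (outputs[i], lses[i]) pair comes from softmax attention restricted
    to KV chunk i, with lses[i] = log(sum_j exp(scale * q·k_j)). Softmax
    over the union of the chunks factorizes through the log-sum-exp, so the
    exact full-context output is a weighted sum of the per-chunk outputs.
    """
    lse = torch.stack(lses, dim=0)                   # [N, num_tokens, H]
    out = torch.stack(outputs, dim=0)                # [N, num_tokens, H, D]
    global_lse = torch.logsumexp(lse, dim=0)         # [num_tokens, H]
    weights = torch.exp(lse - global_lse)            # per-chunk renormalizers
    return (weights.unsqueeze(-1) * out).sum(dim=0)  # [num_tokens, H, D]
```

Numerically this is the same renormalization that the refactored `_compute_prefill_context` performs by stacking each partial output with its LSE and handing the list to `npu_attention_update`.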
@@ -53,6 +53,7 @@ class CPChunkedContextMetadata:
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
chunk_actual_seq_lengths_kv_list: list[list[int]]
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: list[list[int]] | None = None
@@ -30,6 +30,7 @@ from vllm_ascend.attention.mla_v1 import (
# isort: on

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.context_parallel.common_cp import (
AscendPCPMetadata,
CPChunkedContextMetadata,
@@ -189,6 +190,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
chunk_actual_seq_lengths_kv_list=chunked_context_metadata.chunk_actual_seq_lengths_kv_list,
workspace=chunked_context_metadata.workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
@@ -276,6 +278,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
**kwargs,
)

# npu_ring_mla needs bfloat16 512x512 mask, different from FIA's int8 2048x2048 mask
# TODO: Remove this when mla_cp.py also migrates to FIA
self._ring_mla_mask_builder = AttentionMaskBuilder(torch.device("npu"))

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
@@ -484,6 +490,10 @@ class AscendMlaCPImpl(AscendMLAImpl):
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
# Use ring_mla-specific mask (bfloat16, 512x512)
# TODO: Remove this when mla_cp.py migrates to FIA
ring_mla_mask = self._ring_mla_mask_builder.get_mla_mask(self.vllm_config.model_config.dtype)

output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -494,7 +504,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_head_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=attn_metadata.attn_mask,
mask=ring_mla_mask,
)

output_tail, lse_tail = self._attention_with_mask_and_nomask(
@@ -507,7 +517,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
kv_nomask_idx=kv_with_q_tail_nomask_idx,
attn_mask_seqlens=attn_mask_seqlens,
attn_nomask_seqlens=tail_attn_nomask_seqlens,
mask=attn_metadata.attn_mask,
mask=ring_mla_mask,
)

q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
@@ -112,6 +112,7 @@ class ChunkedContextMetadata:
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
chunk_actual_seq_lengths_kv_list: list[list[int]]


@dataclass
@@ -131,6 +132,7 @@ class AscendMLAPrefillMetadata:
sin: torch.Tensor = None
cos: torch.Tensor = None
pcp_metadata: AscendPCPMetadata | None = None
actual_seq_lengths_q: list[int] | None = None


@dataclass
@@ -447,7 +449,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
num_decodes=self.num_decodes,
num_decode_tokens=self.num_decode_tokens,
num_prefills=self.num_prefills,
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
@@ -486,6 +488,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
chunk_actual_seq_lengths_kv_list = [
torch.cumsum(self.chunk_seq_lens[i], dim=0).tolist() for i in range(self.num_chunks)
]
return ChunkedContextMetadata(
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
@@ -494,6 +499,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
)

def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
@@ -527,9 +533,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):

prefill_input_positions = input_positions[tokens_start:]
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
prefill_query_lens = self.query_lens[reqs_start:].to(torch.int32)
actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist()
return AscendMLAPrefillMetadata(
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
query_lens=self.query_lens[reqs_start:].to(torch.int32),
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
query_lens=prefill_query_lens,
seq_lens=self.seq_lens,
context_lens=self.seq_lens[reqs_start:],
input_positions=prefill_input_positions,
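A note on the two lines added in the hunk above: the refactored prefill passes `actual_seq_lengths` to FIA as cumulative end offsets rather than per-request lengths, which is why the builder applies `torch.cumsum` here instead of in the forward pass. A minimal illustration with made-up lengths:

```python
import torch

# Hypothetical per-request prefill query lengths.
prefill_query_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative end offsets, as precomputed into actual_seq_lengths_q.
actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist()
print(actual_seq_lengths_q)  # [3, 8, 10]
```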
@@ -540,6 +548,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
actual_seq_lengths_q=actual_seq_lengths_q,
)

def build_decode_metadata(
@@ -887,8 +896,11 @@ class AscendMLAImpl(MLAAttentionImpl):
post_process_after_loading_for_shard_weight_series(layer)

def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr]
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr]
assert self.fused_qkv_a_proj is not None
assert self.q_a_layernorm is not None
assert self.kv_a_layernorm is not None
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -990,17 +1002,18 @@ class AscendMLAImpl(MLAAttentionImpl):
return prefix_output, prefix_lse

iters = len(prefill_metadata.chunked_context.seq_tot)

current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)

actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q

chunk_outputs = []
chunk_lses = []

for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
# chunk_seq_lens will be padded when pcp&dcp
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([current_seq_len, context_seq_len])
context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
@@ -1026,27 +1039,61 @@ class AscendMLAImpl(MLAAttentionImpl):
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))

mask = attn_metadata.attn_mask
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse,
actual_seq_lengths_kv = prefill_metadata.chunked_context.chunk_actual_seq_lengths_kv_list[i]

chunk_out, chunk_lse = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
v,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
input_layout="TND",
atten_mask=None,
sparse_mode=0,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=actual_seq_lengths_kv,
)
chunk_outputs.append(chunk_out)
chunk_lses.append(chunk_lse)

if len(chunk_outputs) > 0:
num_tokens = q_nope.size(0)
D = self.v_head_dim
H = self.num_heads

# Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
prefix_output = prefix_output.to(torch.float32)
prefix_lse = prefix_lse.to(torch.float32)
if prefix_lse.dim() == 2:
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)

# Concat output and lse: [num_tokens, H, D+1]
all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
chunk_out = chunk_out.to(torch.float32)
chunk_lse = chunk_lse.to(torch.float32)
if chunk_lse.dim() == 2:
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))

# Stack and split: [N, num_tokens, H, D+1]
all_out_lse = torch.stack(all_out_lse, dim=0)
N = all_out_lse.size(0)
out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)

# Flatten and unbind for npu_attention_update
out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
lse_list = lse_flat.view(N, num_tokens * H).unbind(0)

output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
return output_final.view(num_tokens, H, D), None

return prefix_output, prefix_lse

def _forward_prefill(
@@ -1062,33 +1109,54 @@ class AscendMLAImpl(MLAAttentionImpl):
assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1
num_tokens = q_nope.size(0)
prefill_meta = attn_metadata.prefill

actual_seq_lengths_q = prefill_meta.actual_seq_lengths_q
actual_seq_lengths_kv = actual_seq_lengths_q.copy()

# FIA with TND layout only supports bfloat16, convert if needed
original_dtype = q_nope.dtype
need_dtype_convert = original_dtype != torch.bfloat16
if need_dtype_convert:
q_nope = q_nope.to(torch.bfloat16)
q_pe = q_pe.to(torch.bfloat16)
k_nope = k_nope.to(torch.bfloat16)
k_pe = k_pe.to(torch.bfloat16)
value = value.to(torch.bfloat16)

attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=attn_metadata.attn_mask,
seqlen=attn_metadata.prefill.query_lens,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse,
)

common_kwargs = {
"query_rope": q_pe,
"key_rope": k_pe,
"num_heads": self.num_heads,
"num_key_value_heads": self.num_heads,
"input_layout": "TND",
"atten_mask": prefill_meta.attn_mask,
"sparse_mode": 3,
"scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"block_table": None,
"block_size": 0,
"softmax_lse_flag": True,
"actual_seq_lengths": actual_seq_lengths_q,
"actual_seq_lengths_kv": actual_seq_lengths_kv,
}

attn_output, attn_lse = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, value, **common_kwargs)

attn_output, attn_lse = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
)

attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])

# Convert back to original dtype if needed
if need_dtype_convert:
attn_output = attn_output.to(original_dtype)

return attn_output

def exec_kv_decode(
@@ -1099,6 +1167,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache: tuple,
slots: torch.Tensor,
):
assert self.kv_a_layernorm is not None
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
@@ -1126,6 +1195,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_cache: tuple,
slots: torch.Tensor,
):
assert self.kv_a_layernorm is not None
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1