[v0.18.0][BugFix]Revert the code: Replace npu_ring_mla wit FIA with MLA prefill. (#7961)
This pull request reverts previous changes to switch to FIA and instead implements npu_ring_mla for MLA prefill operations(#5704 ). The change streamlines the attention mechanism by removing unnecessary metadata tracking and updating the underlying NPU operations to use the ring-based MLA kernel. This adjustment ensures better compatibility and performance for MLA prefill tasks within the vLLM Ascend backend. Highlights - Migration to npu_ring_mla: Replaced the usage of npu_fused_infer_attention_score (FIA) with npu_ring_mla for MLA prefill operations across the codebase to improve performance and alignment with the intended architecture. - Cleanup of redundant metadata: Removed chunk_actual_seq_lengths_kv_list and actual_seq_lengths_q from various metadata structures as they are no longer required for the updated attention implementation. - Test suite updates: Updated unit tests in test_mla_cp.py and test_mla_v1.py to mock npu_ring_mla instead of the deprecated FIA functions and adjusted test assertions to reflect the new implementation details. Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -114,7 +114,6 @@ class ChunkedContextMetadata:
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
chunk_seq_lens_npu: torch.Tensor
|
||||
chunk_actual_seq_lengths_kv_list: list[list[int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -134,7 +133,6 @@ class AscendMLAPrefillMetadata:
|
||||
sin: torch.Tensor = None
|
||||
cos: torch.Tensor = None
|
||||
pcp_metadata: AscendPCPMetadata | None = None
|
||||
actual_seq_lengths_q: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -452,7 +450,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
num_decodes=self.num_decodes,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
num_prefills=self.num_prefills,
|
||||
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
@@ -492,9 +490,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
|
||||
torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
|
||||
chunk_actual_seq_lengths_kv_list = [
|
||||
torch.cumsum(self.chunk_seq_lens[i], dim=0).tolist() for i in range(self.num_chunks)
|
||||
]
|
||||
return ChunkedContextMetadata(
|
||||
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
|
||||
starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
|
||||
@@ -503,7 +498,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
chunk_seq_lens=self.chunk_seq_lens,
|
||||
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
chunk_actual_seq_lengths_kv_list=chunk_actual_seq_lengths_kv_list,
|
||||
)
|
||||
|
||||
def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
|
||||
@@ -538,9 +532,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
||||
prefill_query_lens = self.query_lens[reqs_start:].to(torch.int32)
|
||||
actual_seq_lengths_q = torch.cumsum(prefill_query_lens, dim=0).tolist()
|
||||
return AscendMLAPrefillMetadata(
|
||||
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
|
||||
query_lens=prefill_query_lens,
|
||||
seq_lens=self.seq_lens,
|
||||
context_lens=self.seq_lens[reqs_start:],
|
||||
@@ -552,7 +545,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
chunked_context=chunked_context_metadata,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
)
|
||||
|
||||
def build_decode_metadata(
|
||||
@@ -1056,29 +1048,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
|
||||
current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||
num_heads = cache_k_pe.size(2)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||
|
||||
actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
|
||||
|
||||
if iters == 0:
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
num_tokens = q_nope.size(0)
|
||||
D = self.v_head_dim
|
||||
H = self.num_heads
|
||||
|
||||
if prefix_lse.dim() == 2:
|
||||
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
|
||||
prefix_output = prefix_output.to(torch.float32)
|
||||
prefix_lse = prefix_lse.to(torch.float32)
|
||||
out_list = [prefix_output.reshape(num_tokens * H, D)]
|
||||
lse_list = [prefix_lse.reshape(num_tokens * H)]
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
# chunk_seq_lens will be padded when pcp&dcp
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata)
|
||||
kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
@@ -1104,35 +1085,29 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
actual_seq_lengths_kv = prefill_metadata.chunked_context.chunk_actual_seq_lengths_kv_list[i]
|
||||
|
||||
chunk_out, chunk_lse = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
v,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_heads,
|
||||
input_layout="TND",
|
||||
atten_mask=None,
|
||||
sparse_mode=0,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
mask = attn_metadata.attn_mask
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse,
|
||||
)
|
||||
if chunk_lse.dim() == 2:
|
||||
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
|
||||
chunk_out = chunk_out.to(torch.float32)
|
||||
chunk_lse = chunk_lse.to(torch.float32)
|
||||
out_list.append(chunk_out.reshape(num_tokens * H, D))
|
||||
lse_list.append(chunk_lse.reshape(num_tokens * H))
|
||||
|
||||
output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0)
|
||||
return output_final.view(num_tokens, H, D), None
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
@@ -1147,54 +1122,35 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
assert attn_metadata.prefill is not None
|
||||
assert len(kv_c_and_k_pe_cache) > 1
|
||||
num_tokens = q_nope.size(0)
|
||||
prefill_meta = attn_metadata.prefill
|
||||
|
||||
actual_seq_lengths_q = prefill_meta.actual_seq_lengths_q
|
||||
actual_seq_lengths_kv = actual_seq_lengths_q.copy()
|
||||
|
||||
# FIA with TND layout only supports bfloat16, convert if needed
|
||||
original_dtype = q_nope.dtype
|
||||
need_dtype_convert = original_dtype != torch.bfloat16
|
||||
if need_dtype_convert:
|
||||
q_nope = q_nope.to(torch.bfloat16)
|
||||
q_pe = q_pe.to(torch.bfloat16)
|
||||
k_nope = k_nope.to(torch.bfloat16)
|
||||
k_pe = k_pe.to(torch.bfloat16)
|
||||
value = value.to(torch.bfloat16)
|
||||
|
||||
attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device)
|
||||
attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device)
|
||||
|
||||
common_kwargs = {
|
||||
"query_rope": q_pe,
|
||||
"key_rope": k_pe,
|
||||
"num_heads": self.num_heads,
|
||||
"num_key_value_heads": self.num_heads,
|
||||
"input_layout": "TND",
|
||||
"atten_mask": prefill_meta.attn_mask,
|
||||
"sparse_mode": 3,
|
||||
"scale": self.scale,
|
||||
"antiquant_mode": 0,
|
||||
"antiquant_scale": None,
|
||||
"block_table": None,
|
||||
"block_size": 0,
|
||||
"softmax_lse_flag": True,
|
||||
"actual_seq_lengths": actual_seq_lengths_q,
|
||||
"actual_seq_lengths_kv": actual_seq_lengths_kv,
|
||||
}
|
||||
|
||||
attn_output, attn_lse = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, value, **common_kwargs)
|
||||
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seqlen=attn_metadata.prefill.query_lens,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse,
|
||||
)
|
||||
attn_output, attn_lse = self._compute_prefill_context(
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim])
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
if need_dtype_convert:
|
||||
attn_output = attn_output.to(original_dtype)
|
||||
|
||||
return attn_output
|
||||
|
||||
def exec_kv_decode(
|
||||
|
||||
Reference in New Issue
Block a user