# [Perf] Deepseekv3 performance optimization for eager mode (#598)
### What this PR does / why we need it?
DeepSeek V3 currently uses vanilla chunked prefill for the MLA part, which is inefficient to compute but necessary for chunked prefill. Since PR https://github.com/vllm-project/vllm-ascend/pull/543 brought the V0 scheduler into vllm-ascend, we can now call torch_npu._npu_flash_attention inside the MLA backend for a further performance boost. Some redundant computation inside the RoPE path is also removed. This PR should bring some performance gain for DeepSeek eager-mode inference.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
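For readers unfamiliar with the trade-off: the vanilla path materializes the up-projected keys/values and runs attention with plain tensor ops, which is what the fused kernel replaces in the non-chunked case. Below is a minimal, illustrative sketch of that reference computation; the shapes and the function name are assumptions for illustration, not the actual `vanilla_chunked_prefill_mla` signature.

```python
import torch

def vanilla_mla_attention(query, key, value, scale):
    # query/key: [num_tokens, num_heads, qk_head_dim]
    # value:     [num_tokens, num_heads, v_head_dim]
    scores = torch.einsum("qhd,khd->hqk", query, key) * scale
    # causal mask: a query token may only attend to itself and earlier tokens
    causal = torch.ones(scores.shape[-2:], dtype=torch.bool).triu(1)
    scores = scores.masked_fill(causal, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), value)
```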
```diff
@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
-    max_context_len: int
+    max_seq_lens: int
 
 
 @dataclass
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    max_seq_lens: int
 
 
 @dataclass
@@ -131,11 +132,6 @@ class AscendMLAMetadataBuilder:
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        # self.attn_mask = None
-        # if AscendMLAMetadataBuilder._attn_mask_builder is None:
-        #     AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-        #         128, self.runner.model_config.dtype
-        #     )
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
```
```diff
@@ -222,12 +218,14 @@ class AscendMLAMetadataBuilder:
                 num_reqs]
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
-        max_context_len = seq_lens.max().item()
+        max_seq_lens = seq_lens.max().item()
 
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
+            max_query_len = query_lens[tokens_start:].max().item()
+            max_seq_lens = seq_lens[tokens_start:].max().item()
 
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -236,15 +234,17 @@ class AscendMLAMetadataBuilder:
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
-                max_context_len=max_context_len,
+                max_seq_lens=max_seq_lens,
             )
 
         decode_metadata = None
         if self._num_decodes > 0:
+            max_seq_lens = seq_lens[:self._num_decodes].max().item()
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions[:self._num_decode_tokens],
                 block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens])
+                seq_lens=seq_lens[:self._num_decode_tokens],
+                max_seq_lens=max_seq_lens)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
```
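The builder above relies on `reorder_batch` having placed decode requests before prefill requests, so each segment's maxima come from simple slices. A small illustrative example with made-up lengths (decode requests contribute one token each, so `num_decodes == num_decode_tokens` here):

```python
import torch

num_decodes, num_decode_tokens = 2, 2   # one token per decode request
query_lens = torch.tensor([1, 1, 5, 9])
seq_lens = torch.tensor([33, 17, 12, 40])

# prefill segment maxima, as computed in the hunk above
max_query_len = query_lens[num_decode_tokens:].max().item()       # 9
prefill_max_seq_lens = seq_lens[num_decode_tokens:].max().item()  # 40
# decode segment maximum, passed into AscendMLADecodeMetadata
decode_max_seq_lens = seq_lens[:num_decodes].max().item()         # 33
```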
```diff
@@ -306,12 +306,18 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: the padding below should be removed once the kernel is ready
+        # we found npu_flash_attention can only work on a 128-divisible head_dim,
+        # so we pad it to the target size here and slice the final result.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128
 
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb.forward_native
+        self.rotary_emb = rotary_emb
 
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
```
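For concreteness, plugging in DeepSeek V3's MLA head sizes (qk_nope_head_dim=128 and qk_rope_head_dim=64, so qk_head_dim=192) shows what the rounding above produces; a quick check:

```python
qk_nope_head_dim, qk_rope_head_dim = 128, 64  # DeepSeek V3 MLA head sizes
padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
assert padding_head_dim == 256  # 192 is not 128-divisible, so pad up to 256
```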
```diff
@@ -409,37 +415,73 @@ class AscendMLAImpl(MLAAttentionImpl):
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
 
-        # TODO: enable this compute for flash attention computation
-        # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-        #     -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
-        # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
-        #                                    value=0)
         num_tokens = query.size(0)
-        attn_output = torch.empty(num_tokens,
-                                  self.num_heads,
-                                  self.v_head_dim,
-                                  dtype=query.dtype,
-                                  device=query.device)
-        # current requests is chunked in prefill, disable flash attention with chunked prefill
-        vanilla_chunked_prefill_mla(
-            output=attn_output,
-            query=query,
-            kv_cache=kv_c_and_k_pe_cache,
-            block_tables=attn_metadata.prefill.block_table,
-            query_lens=attn_metadata.prefill.query_lens,
-            context_lens=attn_metadata.prefill.context_lens,
-            kv_b_proj=self.kv_b_proj,
-            max_query_len=attn_metadata.prefill.max_query_len,
-            max_context_len=attn_metadata.prefill.max_context_len,
-            nope_dim=self.qk_nope_head_dim,
-            rope_dim=self.qk_rope_head_dim,
-            v_head_dim=self.v_head_dim,
-            scale=self.scale,
-            alibi_slopes=None,
-            causal=True)
-        attn_output = attn_output.view(
+        attn_output = None
+        # There are only two possible input states here: ChunkedPrefill or PrefillOnly
+        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads * self.v_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            # the current requests are chunked in prefill, so disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads,
+                                      self.padding_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                            dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
+            torch_npu._npu_flash_attention(
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
+        else:
+            raise RuntimeError(
+                "Unexpected path reached: AscendMLAImpl should only see the PrefillOnly and ChunkedPrefill states in prefill forward; please file a bug to vllm-ascend!"
+            )
+        attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
         return self.o_proj(attn_output)[0]
```
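Two details of the PrefillOnly branch are worth calling out. First, `torch.nn.functional.pad` with a two-element list pads only the last dimension, so just head_dim grows. Second, slicing `[:, :, :self.v_head_dim]` leaves a non-contiguous tensor, which is presumably why the final flatten uses `reshape` instead of the old `view`. A standalone sketch of that pad-run-slice pattern (the fused kernel call is replaced by a stand-in tensor):

```python
import torch
import torch.nn.functional as F

num_tokens, num_heads = 16, 8
qk_head_dim, v_head_dim, padding_head_dim = 192, 128, 256

query = torch.randn(num_tokens, num_heads, qk_head_dim)
pad_query = F.pad(query, [0, padding_head_dim - qk_head_dim])  # last dim -> 256

# stand-in for the fused attention output written into the padded buffer
attn_output = torch.randn(num_tokens, num_heads, padding_head_dim)
attn_output = attn_output[:, :, :v_head_dim]  # drop the padded tail

# the slice is non-contiguous, so view() would raise; reshape() copies as needed
attn_output = attn_output.reshape(num_tokens, num_heads * v_head_dim)
```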
```diff
@@ -457,7 +499,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         q = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q.size(0)
-        attn_output = torch.randn(
+        attn_output = torch.empty(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
@@ -522,8 +564,10 @@ class AscendMLAImpl(MLAAttentionImpl):
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions,
+                decode_q_pe.contiguous(),
+                decode_k_pe,
+                max_seq_len=attn_metadata.decode.max_seq_lens)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
```
```diff
@@ -533,7 +577,9 @@ class AscendMLAImpl(MLAAttentionImpl):
 
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                prefill_q_pe.contiguous(),
+                prefill_k_pe,
+                max_seq_len=attn_metadata.prefill.max_seq_lens)
 
             if kv_cache.numel() > 0:
                 key = torch.cat([
```
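The rotary-embedding calls above now pass `max_seq_len` explicitly. This diff does not show the rope implementation itself, so the following is an assumption on our part: a bound like this typically lets the embedding build its cos/sin table once up to the true batch maximum instead of re-deriving it on every call, which matches the "redundant computation inside the RoPE" claim in the description. A generic sketch of that caching idea, with a hypothetical helper name:

```python
import torch

_COS_SIN_CACHE: dict = {}

def cos_sin_table(rope_dim: int, max_seq_len: int, base: float = 10000.0):
    # Hypothetical helper: build the rope table once per (rope_dim, max_seq_len)
    # so repeated forwards with the same bound do no redundant trig work.
    key = (rope_dim, max_seq_len)
    if key not in _COS_SIN_CACHE:
        inv_freq = 1.0 / (base**(torch.arange(0, rope_dim, 2).float() / rope_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)  # [max_seq_len, rope_dim // 2]
        _COS_SIN_CACHE[key] = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
    return _COS_SIN_CACHE[key]
```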