[Perf] Avoid performing index selection of sin/cos cache every layer (#1890)

Optimize number of index selections of sin/cos cache.

- vLLM version: v0.10.0
- vLLM main:
656c24f1b5

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-07-29 18:06:45 +08:00
committed by GitHub
parent 0190b68f51
commit 98cadc2146
3 changed files with 73 additions and 22 deletions

View File

@@ -80,6 +80,8 @@ class AscendMLAPrefillMetadata:
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
@dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
@dataclass
@@ -200,6 +204,9 @@ class AscendMLAMetadataBuilder:
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -318,13 +325,27 @@ class AscendMLAMetadataBuilder:
-1,
dtype=torch.int32,
device=device)
sin = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
cos = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1,
attn_mask=self.runner.spec_attn_mask)
attn_mask=self.runner.spec_attn_mask,
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
@@ -370,6 +391,16 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
if self.cos_cache is None:
self.cos_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.cos_cached
self.sin_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.runner.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.runner.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.runner.dtype) # type: ignore
prefill_metadata = None
chunked_context_metadata = None
@@ -415,18 +446,26 @@ class AscendMLAMetadataBuilder:
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
query_lens=query_lens[tokens_start:],
seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:],
input_positions=input_positions[tokens_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
decode_metadata = None
@@ -467,6 +506,10 @@ class AscendMLAMetadataBuilder:
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat([input_positions, padding_0])
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
@@ -474,7 +517,9 @@ class AscendMLAMetadataBuilder:
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens,
attn_mask=self.runner.spec_attn_mask)
attn_mask=self.runner.spec_attn_mask,
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -1069,15 +1114,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_k_nope = None
assert attn_metadata.decode is not None
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
@@ -1124,15 +1162,8 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
cos = cos[attn_metadata.prefill.input_positions]
sin = sin[attn_metadata.prefill.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(