[Perf] Avoid performing index selection of sin/cos cache every layer (#1890)

Optimize number of index selections of sin/cos cache.

- vLLM version: v0.10.0
- vLLM main:
656c24f1b5

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-07-29 18:06:45 +08:00
committed by GitHub
parent 0190b68f51
commit 98cadc2146
3 changed files with 73 additions and 22 deletions

View File

@@ -331,15 +331,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
runner.chunked_prefill_enabled = False
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.dtype = torch.float16
builder = AscendMLAMetadataBuilder(runner=runner,
metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
metadata = builder.build_torchair_graph_dummy(3, 3)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
@@ -354,6 +369,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
class TestAscendMLAImpl(TestBase):

View File

@@ -80,6 +80,8 @@ class AscendMLAPrefillMetadata:
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
@dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
@dataclass
@@ -200,6 +204,9 @@ class AscendMLAMetadataBuilder:
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -318,13 +325,27 @@ class AscendMLAMetadataBuilder:
-1,
dtype=torch.int32,
device=device)
sin = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
cos = torch.ones(num_reqs,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1,
attn_mask=self.runner.spec_attn_mask)
attn_mask=self.runner.spec_attn_mask,
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
@@ -370,6 +391,16 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
if self.cos_cache is None:
self.cos_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.cos_cached
self.sin_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.runner.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.runner.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.runner.dtype) # type: ignore
prefill_metadata = None
chunked_context_metadata = None
@@ -415,18 +446,26 @@ class AscendMLAMetadataBuilder:
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
query_lens=query_lens[tokens_start:],
seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:],
input_positions=input_positions[tokens_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
decode_metadata = None
@@ -467,6 +506,10 @@ class AscendMLAMetadataBuilder:
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat([input_positions, padding_0])
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
@@ -474,7 +517,9 @@ class AscendMLAMetadataBuilder:
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens,
attn_mask=self.runner.spec_attn_mask)
attn_mask=self.runner.spec_attn_mask,
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -1069,15 +1114,8 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_k_nope = None
assert attn_metadata.decode is not None
if self.running_in_graph:
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=decode_hs_or_q_c.dtype)
cos = cos[attn_metadata.decode.input_positions]
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
@@ -1124,15 +1162,8 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.torchair_graph_enabled:
num_tokens = prefill_hs_or_q_c.shape[0]
seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
cos = self.rotary_emb.cos_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(
dtype=prefill_q_pe.dtype)
cos = cos[attn_metadata.prefill.input_positions]
sin = sin[attn_metadata.prefill.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(

View File

@@ -1799,6 +1799,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(