[Perf] Avoid performing index selection of sin/cos cache every layer (#1890)

Reduce the number of index selections on the sin/cos cache by performing the selection once instead of in every layer.

- vLLM version: v0.10.0
- vLLM main: 656c24f1b5

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-07-29 18:06:45 +08:00
committed by GitHub
parent 0190b68f51
commit 98cadc2146
3 changed files with 73 additions and 22 deletions

View File

@@ -331,15 +331,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
runner.chunked_prefill_enabled = False
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.dtype = torch.float16
builder = AscendMLAMetadataBuilder(runner=runner,
metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
metadata = builder.build_torchair_graph_dummy(3, 3)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
@@ -354,6 +369,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
class TestAscendMLAImpl(TestBase):