[Perf] Avoid performing index selection of sin/cos cache every layer (#1890)
Reduce the number of index selections on the sin/cos cache so that the selection is not repeated in every layer.
- vLLM version: v0.10.0
- vLLM main: 656c24f1b5
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -331,15 +331,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
runner.chunked_prefill_enabled = False
|
||||
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
|
||||
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
|
||||
runner.dtype = torch.float16
|
||||
|
||||
builder = AscendMLAMetadataBuilder(runner=runner,
|
||||
metadata_cls=AscendMLAMetadata)
|
||||
builder.rope_dim = 64
|
||||
|
||||
with patch.object(builder,
|
||||
"_get_graph_runner_block_tables",
|
||||
side_effect=lambda x, y: y):
|
||||
metadata = builder.build_torchair_graph_dummy(3, 3)
|
||||
|
||||
sin_golden = torch.ones(3,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
dtype=runner.dtype,
|
||||
device=runner.device)
|
||||
cos_golden = torch.ones(3,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
dtype=runner.dtype,
|
||||
device=runner.device)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLAMetadata)
|
||||
self.assertEqual(metadata.num_input_tokens, 3)
|
||||
self.assertEqual(metadata.num_actual_tokens, 3)
|
||||
@@ -354,6 +369,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
self.assertEqual(metadata.seq_lens.shape[0], 3)
|
||||
self.assertEqual(metadata.slot_mapping.shape[0], 3)
|
||||
self.assertEqual(metadata.query_start_loc.shape[0], 3)
|
||||
assert torch.equal(sin_golden, metadata.decode.sin)
|
||||
assert torch.equal(cos_golden, metadata.decode.cos)
|
||||
|
||||
|
||||
class TestAscendMLAImpl(TestBase):
|
||||
|
||||
Reference in New Issue
Block a user