[Perf] Avoid performing index selection of sin/cos cache every layer (#1890)
Reduce the number of sin/cos cache index selections: fetch the rotary cos/sin tables once in the attention metadata builder and index-select them once per step, instead of slicing, casting, and gathering them again in every attention layer.
- vLLM version: v0.10.0
- vLLM main: 656c24f1b5
Signed-off-by: whx-sjtu <2952154980@qq.com>
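The pattern behind the change: the per-layer work (slice the rotary tables, cast to the activation dtype, index-select by position, reshape) moves into the attention metadata builder, which caches the tables on first use and gathers them once per step; every layer then reads the ready-made sin/cos tensors from the metadata. A minimal sketch of the idea, assuming a rotary embedding that precomputes cos_cached/sin_cached lookup tables (as DeepSeek-style MLA models do); RotarySinCosProvider and its methods are illustrative names, not the vllm-ascend API:

import torch

class RotarySinCosProvider:
    def __init__(self, rotary_emb, dtype: torch.dtype):
        # Fetch and cast the tables once, instead of once per layer per step.
        self.cos_cache = rotary_emb.cos_cached.to(dtype)
        self.sin_cache = rotary_emb.sin_cached.to(dtype)

    def select(self, positions: torch.Tensor):
        # One gather per step, reshaped to [tokens, 1, 1, rope_dim];
        # every decoder layer reuses the result.
        cos = self.cos_cache[positions].unsqueeze(1).unsqueeze(2)
        sin = self.sin_cache[positions].unsqueeze(1).unsqueeze(2)
        return cos, sin

With dozens of decoder layers this replaces O(num_layers) gathers and dtype casts per step with a single one, and hands graph mode stable tensors that can be marked static.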
@@ -331,15 +331,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
         runner.chunked_prefill_enabled = False
         runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
         runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
+        runner.dtype = torch.float16
 
         builder = AscendMLAMetadataBuilder(runner=runner,
                                            metadata_cls=AscendMLAMetadata)
+        builder.rope_dim = 64
 
         with patch.object(builder,
                           "_get_graph_runner_block_tables",
                           side_effect=lambda x, y: y):
             metadata = builder.build_torchair_graph_dummy(3, 3)
 
+        sin_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+        cos_golden = torch.ones(3,
+                                1,
+                                1,
+                                64,
+                                dtype=runner.dtype,
+                                device=runner.device)
+
         self.assertIsInstance(metadata, AscendMLAMetadata)
         self.assertEqual(metadata.num_input_tokens, 3)
         self.assertEqual(metadata.num_actual_tokens, 3)
@@ -354,6 +369,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
         self.assertEqual(metadata.seq_lens.shape[0], 3)
         self.assertEqual(metadata.slot_mapping.shape[0], 3)
         self.assertEqual(metadata.query_start_loc.shape[0], 3)
+        assert torch.equal(sin_golden, metadata.decode.sin)
+        assert torch.equal(cos_golden, metadata.decode.cos)
 
 
 class TestAscendMLAImpl(TestBase):
@@ -80,6 +80,8 @@ class AscendMLAPrefillMetadata:
     max_query_len: int
     max_seq_lens: int
     chunked_context: Optional[ChunkedContextMetadata] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -92,6 +94,8 @@ class AscendMLADecodeMetadata:
     max_seq_lens: int
     seq_lens_list: list[int]
     attn_mask: Optional[torch.Tensor] = None
+    sin: torch.Tensor = None
+    cos: torch.Tensor = None
 
 
 @dataclass
@@ -200,6 +204,9 @@ class AscendMLAMetadataBuilder:
         )
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
+        self.cos_cache = None
+        self.sin_cache = None
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -318,13 +325,27 @@ class AscendMLAMetadataBuilder:
                                    -1,
                                    dtype=torch.int32,
                                    device=device)
+        sin = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
+        cos = torch.ones(num_reqs,
+                         1,
+                         1,
+                         self.rope_dim,
+                         dtype=self.runner.dtype,
+                         device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_seq_lens=1,
-            attn_mask=self.runner.spec_attn_mask)
+            attn_mask=self.runner.spec_attn_mask,
+            sin=sin,
+            cos=cos)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -370,6 +391,16 @@ class AscendMLAMetadataBuilder:
             seq_lens = seq_lens_cpu
             max_query_len = query_lens.max().item()
             max_seq_lens = seq_lens.max().item()
+        if self.cos_cache is None:
+            self.cos_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.cos_cached
+            self.sin_cache = self.runner.get_model(
+            ).model.layers[0].self_attn.rotary_emb.sin_cached
+        if self.cos_cache.dtype != self.runner.dtype:  # type: ignore
+            self.cos_cache = self.cos_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
+            self.sin_cache = self.sin_cache.to(  # type: ignore
+                self.runner.dtype)  # type: ignore
 
         prefill_metadata = None
         chunked_context_metadata = None
@@ -415,18 +446,26 @@ class AscendMLAMetadataBuilder:
                     chunk_seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
 
+            prefill_input_positions = input_positions[tokens_start:]
+            cos = self.cos_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
+            sin = self.sin_cache[
+                prefill_input_positions].unsqueeze(  # type: ignore
+                    1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
-                input_positions=input_positions[tokens_start:],
+                input_positions=prefill_input_positions,
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
                 max_seq_lens=max_seq_lens,
                 query_start_loc=prefill_query_start_loc,
                 chunked_context=chunked_context_metadata,
+                sin=sin,
+                cos=cos,
             )
 
         decode_metadata = None
@@ -467,6 +506,10 @@ class AscendMLAMetadataBuilder:
                     dtype=input_positions.dtype,
                     device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
+            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
+            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+                1).unsqueeze(2)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -474,7 +517,9 @@ class AscendMLAMetadataBuilder:
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens,
-                attn_mask=self.runner.spec_attn_mask)
+                attn_mask=self.runner.spec_attn_mask,
+                sin=sin,
+                cos=cos)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -1069,15 +1114,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         decode_k_nope = None
         assert attn_metadata.decode is not None
         if self.running_in_graph:
-            seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-            cos = self.rotary_emb.cos_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            sin = self.rotary_emb.sin_cached[:seq_len].to(
-                dtype=decode_hs_or_q_c.dtype)
-            cos = cos[attn_metadata.decode.input_positions]
-            sin = sin[attn_metadata.decode.input_positions]
-            cos = cos[:, None, None, :]
-            sin = sin[:, None, None, :]
+            cos = attn_metadata.decode.cos
+            sin = attn_metadata.decode.sin
             with npu_stream_switch("mla_secondary",
                                    0,
                                    enabled=enable_multistream_mla):
@@ -1124,15 +1162,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
-                seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=prefill_q_pe.dtype)
-                cos = cos[attn_metadata.prefill.input_positions]
-                sin = sin[attn_metadata.prefill.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
+                cos = attn_metadata.prefill.cos
+                sin = attn_metadata.prefill.sin
 
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
@@ -1799,6 +1799,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         attn_metadata.decode.input_positions)
                     torch._dynamo.mark_static(
                         get_forward_context().mc2_mask)
+                    if hasattr(attn_metadata.decode, "sin"):
+                        torch._dynamo.mark_static(attn_metadata.decode.sin)
+                        torch._dynamo.mark_static(attn_metadata.decode.cos)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                 for kv in self.kv_caches:
                     assert isinstance(
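Note on the hasattr guard above: it keeps the model runner generic, since non-MLA decode metadata carries no sin/cos fields. torch._dynamo.mark_static is the stock PyTorch hook for declaring that a tensor stays fixed across graph replays; a toy standalone illustration with dummy tensors (not the runner's real metadata):

import torch

# Dummy stand-ins for the cached sin/cos tensors built once per step.
sin = torch.ones(3, 1, 1, 64)
cos = torch.ones(3, 1, 1, 64)
# Tell TorchDynamo these tensors are not dynamically shaped, avoiding
# dynamic-shape guards and recompilation in captured graphs.
torch._dynamo.mark_static(sin)
torch._dynamo.mark_static(cos)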