[feature]dcp&pcp support mlapo (#5672)
### What this PR does / why we need it?
mlapo in deepseek is a huge performance improvement in decode, this pr
support pcp & dcp with mlapo
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
This commit is contained in:
@@ -70,6 +70,26 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AscendMLAMetadata:
|
||||
metadata_cls = super().build(common_prefix_len, common_attn_metadata)
|
||||
if self.num_prefills == 0 and self.pcp_size > 1:
|
||||
self.slot_mapping[:self.
|
||||
num_decode_tokens] = self.slot_mapping[:self.
|
||||
num_decode_tokens
|
||||
* self.
|
||||
pcp_size:
|
||||
self.
|
||||
pcp_size]
|
||||
self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens *
|
||||
self.pcp_size].fill_(-1)
|
||||
metadata_cls.slot_mapping = self.slot_mapping
|
||||
return metadata_cls
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AscendMlaCPMetadataBuilder"],
|
||||
@@ -363,8 +383,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
|
||||
decode_ql_nope, decode_q_pe)
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
|
||||
self.pcp_size:self.pcp_size]
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
|
||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||
|
||||
@@ -438,7 +438,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
if self.num_decodes > 0:
|
||||
decode_metadata = self.build_decode_metadata(
|
||||
common_prefix_len, common_attn_metadata)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens_pcp_padded=self.num_actual_tokens,
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
@@ -1330,7 +1329,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.W_UK_T,
|
||||
decode_k_nope,
|
||||
decode_k_pe,
|
||||
attn_metadata.slot_mapping[:bsz].flatten(),
|
||||
attn_metadata.slot_mapping[:bsz],
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
|
||||
Reference in New Issue
Block a user