[bugfix] pcp + mtp acl graph bugfix (#4221)
Fix pcp + mtp bug while using acl graph.
While using pcp + mtp, we need to flatten block_table to avoid irregular
attn mask shape, this was done in mla attn_metadata builder, but we
found out that this influences block_table address and leads to
incorrect results while enable acl graph.
To fix this, we enlarge block_table buffer size and flatten block_table
in model_runner prepare_inputs, so this will not influence block_table
address.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -369,6 +369,12 @@ class AscendMLAMetadataBuilder:
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
if self.pcp_size > 1:
|
||||
num_decodes_flatten = num_decodes * self.decode_threshold
|
||||
block_table = common_attn_metadata.block_table_tensor[:
|
||||
num_decodes_flatten
|
||||
+
|
||||
num_prefills]
|
||||
|
||||
if num_actual_tokens_pcp_padded is None:
|
||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||
@@ -546,6 +552,9 @@ class AscendMLAMetadataBuilder:
|
||||
cos=cos,
|
||||
pcp_metadata=pcp_metadata,
|
||||
)
|
||||
if self.pcp_size > 1:
|
||||
prefill_metadata.block_table = block_table[
|
||||
num_decodes_flatten:, ...]
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
@@ -556,12 +565,12 @@ class AscendMLAMetadataBuilder:
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decodes]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||
# to avoid irregular spec_attn_mask shape
|
||||
if self.pcp_size > 1 and self.decode_threshold > 1:
|
||||
block_table = block_table.repeat_interleave(
|
||||
self.decode_threshold, dim=0)
|
||||
if self.pcp_size > 1:
|
||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||
# to avoid irregular spec_attn_mask shape
|
||||
block_table = block_table[:num_decodes_flatten, ...]
|
||||
else:
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
if num_computed_tokens_of_pcp_dcp is not None:
|
||||
|
||||
Reference in New Issue
Block a user