[bugfix] pcp + mtp acl graph bugfix (#4221)
Fix pcp + mtp bug while using acl graph.
While using pcp + mtp, we need to flatten block_table to avoid irregular
attn mask shape, this was done in mla attn_metadata builder, but we
found out that this influences block_table address and leads to
incorrect results while enable acl graph.
To fix this, we enlarge block_table buffer size and flatten block_table
in model_runner prepare_inputs, so this will not influence block_table
address.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -369,6 +369,12 @@ class AscendMLAMetadataBuilder:
|
|||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
num_decodes_flatten = num_decodes * self.decode_threshold
|
||||||
|
block_table = common_attn_metadata.block_table_tensor[:
|
||||||
|
num_decodes_flatten
|
||||||
|
+
|
||||||
|
num_prefills]
|
||||||
|
|
||||||
if num_actual_tokens_pcp_padded is None:
|
if num_actual_tokens_pcp_padded is None:
|
||||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||||
@@ -546,6 +552,9 @@ class AscendMLAMetadataBuilder:
|
|||||||
cos=cos,
|
cos=cos,
|
||||||
pcp_metadata=pcp_metadata,
|
pcp_metadata=pcp_metadata,
|
||||||
)
|
)
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
prefill_metadata.block_table = block_table[
|
||||||
|
num_decodes_flatten:, ...]
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
@@ -556,12 +565,12 @@ class AscendMLAMetadataBuilder:
|
|||||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||||
seq_lens = seq_lens[:num_decodes]
|
seq_lens = seq_lens[:num_decodes]
|
||||||
input_positions = input_positions[:num_decode_tokens]
|
input_positions = input_positions[:num_decode_tokens]
|
||||||
block_table = block_table[:num_decodes, ...]
|
if self.pcp_size > 1:
|
||||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||||
# to avoid irregular spec_attn_mask shape
|
# to avoid irregular spec_attn_mask shape
|
||||||
if self.pcp_size > 1 and self.decode_threshold > 1:
|
block_table = block_table[:num_decodes_flatten, ...]
|
||||||
block_table = block_table.repeat_interleave(
|
else:
|
||||||
self.decode_threshold, dim=0)
|
block_table = block_table[:num_decodes, ...]
|
||||||
seq_lens_list = seq_lens.tolist()
|
seq_lens_list = seq_lens.tolist()
|
||||||
|
|
||||||
if num_computed_tokens_of_pcp_dcp is not None:
|
if num_computed_tokens_of_pcp_dcp is not None:
|
||||||
|
|||||||
@@ -27,13 +27,29 @@ class BlockTable:
|
|||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
kernel_sizes: Union[list[int], None] = None,
|
kernel_sizes: Union[list[int], None] = None,
|
||||||
cp_kv_cache_interleave_size: int = 1):
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
|
num_speculative_tokens: int = 0):
|
||||||
self.max_num_reqs = max_num_reqs
|
self.max_num_reqs = max_num_reqs
|
||||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||||
self.max_num_batched_tokens = max_num_batched_tokens
|
self.max_num_batched_tokens = max_num_batched_tokens
|
||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.device = device
|
self.device = device
|
||||||
self.physical_block_size = block_size
|
self.physical_block_size = block_size
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.pcp_world_size = get_pcp_group(
|
||||||
|
).world_size if prefill_context_parallel_enable() else 1
|
||||||
|
self.pcp_rank = get_pcp_group(
|
||||||
|
).rank_in_group if self.pcp_world_size > 1 else 0
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
# DCP might not be initialized in testing
|
||||||
|
self.dcp_world_size = 1
|
||||||
|
self.dcp_rank = 0
|
||||||
|
self.pcp_world_size = 1
|
||||||
|
self.pcp_rank = 0
|
||||||
|
|
||||||
# If kernel_sizes is None or [0], use physical block size (no splitting)
|
# If kernel_sizes is None or [0], use physical block size (no splitting)
|
||||||
if kernel_sizes is None or kernel_sizes == [0]:
|
if kernel_sizes is None or kernel_sizes == [0]:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
@@ -69,13 +85,16 @@ class BlockTable:
|
|||||||
else:
|
else:
|
||||||
logical_table_size = max_num_blocks_per_req
|
logical_table_size = max_num_blocks_per_req
|
||||||
|
|
||||||
|
duplicate_size = 1
|
||||||
|
if self.pcp_world_size > 1:
|
||||||
|
duplicate_size += num_speculative_tokens
|
||||||
self.block_table = torch.zeros(
|
self.block_table = torch.zeros(
|
||||||
(max_num_reqs, logical_table_size),
|
(max_num_reqs * duplicate_size, logical_table_size),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
self.block_table_cpu = torch.zeros(
|
self.block_table_cpu = torch.zeros(
|
||||||
(max_num_reqs, logical_table_size),
|
(max_num_reqs * duplicate_size, logical_table_size),
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
@@ -83,20 +102,6 @@ class BlockTable:
|
|||||||
self.block_table_np = self.block_table_cpu.numpy()
|
self.block_table_np = self.block_table_cpu.numpy()
|
||||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
try:
|
|
||||||
self.pcp_world_size = get_pcp_group(
|
|
||||||
).world_size if prefill_context_parallel_enable() else 1
|
|
||||||
self.pcp_rank = get_pcp_group(
|
|
||||||
).rank_in_group if self.pcp_world_size > 1 else 0
|
|
||||||
self.dcp_world_size = get_dcp_group().world_size
|
|
||||||
self.dcp_rank = get_dcp_group().rank_in_group
|
|
||||||
except AssertionError:
|
|
||||||
# DCP might not be initialized in testing
|
|
||||||
self.dcp_world_size = 1
|
|
||||||
self.dcp_rank = 0
|
|
||||||
self.pcp_world_size = 1
|
|
||||||
self.pcp_rank = 0
|
|
||||||
|
|
||||||
self.slot_mapping_cpu = torch.zeros(
|
self.slot_mapping_cpu = torch.zeros(
|
||||||
self.max_num_batched_tokens +
|
self.max_num_batched_tokens +
|
||||||
2 * self.pcp_world_size * self.max_num_reqs,
|
2 * self.pcp_world_size * self.max_num_reqs,
|
||||||
@@ -306,7 +311,7 @@ class MultiGroupBlockTable:
|
|||||||
block_size * dcp_world_size * pcp_world_size),
|
block_size * dcp_world_size * pcp_world_size),
|
||||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||||
pin_memory, device, kernel_size_list,
|
pin_memory, device, kernel_size_list,
|
||||||
cp_kv_cache_interleave_size)
|
cp_kv_cache_interleave_size, num_speculative_tokens)
|
||||||
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
|
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -596,6 +596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.is_pooling_model,
|
self.is_pooling_model,
|
||||||
self.vllm_config.model_config.logits_processors),
|
self.vllm_config.model_config.logits_processors),
|
||||||
is_pooling_model=self.is_pooling_model,
|
is_pooling_model=self.is_pooling_model,
|
||||||
|
num_speculative_tokens=(
|
||||||
|
self.vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
if self.vllm_config.speculative_config else 0),
|
||||||
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
|
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
|
||||||
cp_kv_cache_interleave_size=self.parallel_config.
|
cp_kv_cache_interleave_size=self.parallel_config.
|
||||||
cp_kv_cache_interleave_size
|
cp_kv_cache_interleave_size
|
||||||
@@ -1922,6 +1925,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
prefill_context_parallel_metadata=long_seq_metadata,
|
prefill_context_parallel_metadata=long_seq_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.speculative_config and self.pcp_size > 1:
|
||||||
|
# For pcp + spec decode, we flatten block_table
|
||||||
|
# to avoid irregular spec_attn_mask shape, e.g.,
|
||||||
|
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
|
||||||
|
# ori block_table: # [d0, d1, p0, p1, p2]
|
||||||
|
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||||||
|
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
||||||
|
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
||||||
|
ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||||
|
self.query_start_loc_pcp_full_cpu[:num_reqs]
|
||||||
|
num_prefill_reqs = (ori_query_lens
|
||||||
|
> self.decode_threshold).sum().item()
|
||||||
|
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||||
|
num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
|
||||||
|
blk_table_tensor[
|
||||||
|
num_decode_reqs_flatten:num_decode_reqs_flatten +
|
||||||
|
num_prefill_reqs].copy_(
|
||||||
|
blk_table_tensor[num_decode_reqs:num_decode_reqs +
|
||||||
|
num_prefill_reqs].clone())
|
||||||
|
blk_table_tensor[:num_decode_reqs_flatten].copy_(
|
||||||
|
blk_table_tensor[:num_decode_reqs].repeat_interleave(
|
||||||
|
self.decode_threshold, dim=0))
|
||||||
|
common_attn_metadata.block_table_tensor = \
|
||||||
|
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
|
||||||
|
|
||||||
if self.speculative_config and \
|
if self.speculative_config and \
|
||||||
self.spec_decode_common_attn_metadata is None:
|
self.spec_decode_common_attn_metadata is None:
|
||||||
self.spec_decode_common_attn_metadata = common_attn_metadata
|
self.spec_decode_common_attn_metadata = common_attn_metadata
|
||||||
@@ -2831,6 +2859,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
sin=self.sin,
|
sin=self.sin,
|
||||||
prefill_context_parallel_metadata=long_seq_metadata,
|
prefill_context_parallel_metadata=long_seq_metadata,
|
||||||
)
|
)
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
common_attn_metadata.block_table_tensor = \
|
||||||
|
block_table_tensor[:num_reqs * self.decode_threshold]
|
||||||
attn_state = AscendAttentionState.DecodeOnly
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
if self.speculative_config and \
|
if self.speculative_config and \
|
||||||
self.speculative_config.method == "deepseek_mtp":
|
self.speculative_config.method == "deepseek_mtp":
|
||||||
|
|||||||
Reference in New Issue
Block a user