[bugfix] pcp + mtp acl graph bugfix (#4221)
Fix a pcp + mtp bug that occurs when ACL graph is enabled.
When using pcp + mtp, we need to flatten block_table to avoid an
irregular attn mask shape. This was previously done in the mla
attn_metadata builder, but we found that doing so affects the
block_table address and leads to incorrect results when ACL graph is
enabled.
To fix this, we enlarge the block_table buffer and flatten block_table
in model_runner's prepare_inputs instead, so that flattening no longer
affects the block_table address.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -27,13 +27,29 @@ class BlockTable:
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_sizes: Union[list[int], None] = None,
|
||||
cp_kv_cache_interleave_size: int = 1):
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
num_speculative_tokens: int = 0):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
self.physical_block_size = block_size
|
||||
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group(
|
||||
).world_size if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_pcp_group(
|
||||
).rank_in_group if self.pcp_world_size > 1 else 0
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
|
||||
# If kernel_sizes is None or [0], use physical block size (no splitting)
|
||||
if kernel_sizes is None or kernel_sizes == [0]:
|
||||
self.block_size = block_size
|
||||
@@ -69,13 +85,16 @@ class BlockTable:
|
||||
else:
|
||||
logical_table_size = max_num_blocks_per_req
|
||||
|
||||
duplicate_size = 1
|
||||
if self.pcp_world_size > 1:
|
||||
duplicate_size += num_speculative_tokens
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, logical_table_size),
|
||||
(max_num_reqs * duplicate_size, logical_table_size),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs, logical_table_size),
|
||||
(max_num_reqs * duplicate_size, logical_table_size),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
@@ -83,20 +102,6 @@ class BlockTable:
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group(
|
||||
).world_size if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_pcp_group(
|
||||
).rank_in_group if self.pcp_world_size > 1 else 0
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
|
||||
self.slot_mapping_cpu = torch.zeros(
|
||||
self.max_num_batched_tokens +
|
||||
2 * self.pcp_world_size * self.max_num_reqs,
|
||||
@@ -306,7 +311,7 @@ class MultiGroupBlockTable:
|
||||
block_size * dcp_world_size * pcp_world_size),
|
||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||
pin_memory, device, kernel_size_list,
|
||||
cp_kv_cache_interleave_size)
|
||||
cp_kv_cache_interleave_size, num_speculative_tokens)
|
||||
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user