[feature] support pcp + mtp in full graph (#4572)
1. Support pcp (prefill context parallel) + mtp (multi-token prediction) in full graph mode
2. Fix pcp/dcp-related mtp bugs
3. Support pcp + mtpx (mtp with more than one speculative token)
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
@@ -75,7 +75,7 @@ class BlockTable:
         logical_table_size = max_num_blocks_per_req

         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
                                              logical_table_size,
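Note (illustrative): with speculative decoding under pcp/dcp, each decode request may later be flattened into 1 + num_speculative_tokens block-table rows, which is what the duplicate_size over-allocation reserves. A minimal standalone sketch of the sizing arithmetic, with made-up numbers:

    # Toy sizing example; names mirror the diff, values are arbitrary.
    max_num_reqs = 4
    num_speculative_tokens = 2
    pcp_world_size, dcp_world_size = 2, 1

    duplicate_size = 1
    if pcp_world_size * dcp_world_size > 1:
        duplicate_size += num_speculative_tokens   # 3

    num_rows = max_num_reqs * duplicate_size       # 12 physical rows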
@@ -280,7 +280,7 @@ class NPUModelRunner(GPUModelRunner):
                                              dtype=torch.int32,
                                              device=self.device)
         self.num_actual_tokens_pcp_padded = 0
-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
                                                         dtype=torch.int32)
             self.query_start_loc_pcp_full = self._make_buffer(
@@ -289,8 +289,9 @@ class NPUModelRunner(GPUModelRunner):
                 dtype=torch.int64,
                 device="cpu",
                 pin_memory=True)
-            self.decode_token_per_req += self.speculative_config.num_speculative_tokens
             self.positions_pcp_full_np = self.positions_pcp_full.numpy()
+            self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         self.decode_threshold = 1 + (
             self.speculative_config.num_speculative_tokens
             if self.speculative_config else 0)
@@ -575,6 +576,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.pcp_size > 1:
             if not self.vllm_config.model_config.use_mla:
                 self.generate_kv_idx(scheduler_output)
+            tokens_before_update = tokens.copy()
            tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
                tokens)
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -591,7 +593,8 @@ class NPUModelRunner(GPUModelRunner):
         num_valid_tokens = np.array([
             num_tokens -
             len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
-            for num_tokens, i in zip(tokens, req_ids)
+            for num_tokens, i in zip((tokens_before_update if self.
+                                      pcp_size > 1 else tokens), req_ids)
         ],
                                     dtype=np.int32)

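Note (illustrative): _update_tokens_for_pcp rewrites the per-request token counts to this rank's local shard, while num_valid_tokens must subtract draft tokens from the original global counts, hence the tokens_before_update copy. A toy sketch of the distinction; the exact split performed by _update_tokens_for_pcp is assumed here, not taken from this diff:

    import numpy as np

    pcp_size = 2
    tokens = [8, 4]                    # global scheduled tokens per request
    tokens_before_update = tokens.copy()
    tokens = [t // pcp_size for t in tokens]   # assumed rank-local shard: [4, 2]

    num_draft = [2, 2]                 # draft tokens per request
    num_valid_tokens = np.array(
        [t - d for t, d in zip(tokens_before_update, num_draft)],
        dtype=np.int32)                # [6, 2], computed from global counts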
@@ -909,7 +912,8 @@ class NPUModelRunner(GPUModelRunner):
                  >= self.input_batch.num_prompt_tokens[req_idx]) else -1)

             spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
+                num_draft_tokens, cu_num_tokens,
+                self.num_pcp_pads[:num_reqs].numpy())
             logits_indices = spec_decode_metadata.logits_indices

             # For DECODE only cuda graph of some attention backends (e.g., GDN).
@@ -931,10 +935,11 @@ class NPUModelRunner(GPUModelRunner):
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()

-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self._generate_pcp_mtp_input(
                 num_reqs, scheduler_output.total_num_scheduled_tokens,
-                scheduler_output.num_scheduled_tokens)
+                scheduler_output.num_scheduled_tokens, with_prefill,
+                req_indices, positions_np, cu_num_tokens)

         long_seq_metadata = self._generate_pcp_metadata(
             total_num_scheduled_tokens)
@@ -1040,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
             prefill_context_parallel_metadata=long_seq_metadata,
         )

-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             # For pcp + spec decode, we flatten block_table
             # to avoid irregular spec_attn_mask shape, e.g.,
             # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
@@ -1048,12 +1053,13 @@ class NPUModelRunner(GPUModelRunner):
             # (num_reqs_d + num_reqs_p, max_num_blocks),
             # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
             # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-            ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
-                self.query_start_loc_pcp_full.cpu[:num_reqs]
+            ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
+            ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
             num_prefill_reqs = (ori_query_lens
                                 > self.decode_threshold).sum().item()
             num_decode_reqs = num_reqs - num_prefill_reqs
-            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
+            num_decode_reqs_flatten = \
+                ori_query_lens_cpu[:num_decode_reqs].sum().item()
             blk_table_tensor[
                 num_decode_reqs_flatten:num_decode_reqs_flatten +
                 num_prefill_reqs].copy_(
@@ -1061,9 +1067,15 @@ class NPUModelRunner(GPUModelRunner):
                     num_prefill_reqs].clone())
             blk_table_tensor[:num_decode_reqs_flatten].copy_(
                 blk_table_tensor[:num_decode_reqs].repeat_interleave(
-                    self.decode_threshold, dim=0))
+                    ori_query_lens[:num_decode_reqs], dim=0))
             common_attn_metadata.block_table_tensor = \
                 blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
+            long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+            if 'pad_size' in locals() and pad_size > 0:
+                ori_query_lens_cpu[-pad_size:] = \
+                    torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+            long_seq_metadata.max_query_len_pcp_full = \
+                ori_query_lens_cpu.max().item()

         if self.speculative_config and \
             self.spec_decode_common_attn_metadata is None:
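Note (illustrative): the flattening can be reproduced in isolation. A minimal sketch of the [d0, d1, p0, p1, p2] -> [d0, d0, d1, d1, p0, p1, p2] example from the comments above, using repeat_interleave with per-request query lengths (toy shapes, not the runner's real buffers):

    import torch

    decode_threshold = 2                          # 1 + num_speculative_tokens
    blk_table = torch.arange(5).unsqueeze(1)      # rows [d0, d1, p0, p1, p2]
    query_lens = torch.tensor([2, 2, 7, 9, 5])    # decodes first, len <= threshold

    num_prefill = (query_lens > decode_threshold).sum().item()   # 3
    num_decode = blk_table.size(0) - num_prefill                 # 2
    flat_decode = blk_table[:num_decode].repeat_interleave(
        query_lens[:num_decode], dim=0)           # rows [d0, d0, d1, d1]
    flat = torch.cat([flat_decode, blk_table[num_decode:]])      # 7 rows total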
@@ -1861,7 +1873,7 @@ class NPUModelRunner(GPUModelRunner):
                 decode_token_per_req=self.decode_token_per_req,
                 prefill_context_parallel_metadata=long_seq_metadata,
             )
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 common_attn_metadata.block_table_tensor = \
                     block_table_tensor[:num_reqs * self.decode_threshold]
             attn_state = AscendAttentionState.DecodeOnly
@@ -3029,9 +3041,7 @@ class NPUModelRunner(GPUModelRunner):
         num_reqs = self.input_batch.num_reqs
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
         tokens = np.array(tokens, dtype=np.int32)
-        num_decode_reqs = sum(
-            self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
-            self.input_batch.num_prompt_tokens[:num_reqs])
+        num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
         num_decode_tokens = sum(tokens[:num_decode_reqs])
         num_padded_scheduled_tokens = np.ceil(
             tokens /
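Note (illustrative): the new rule classifies a request as decode purely from its scheduled token count, so it also holds when draft tokens make a decode step schedule more than one token. A toy example (arbitrary numbers; assumes decode requests are ordered first, as the slicing below requires):

    import numpy as np

    decode_threshold = 3                  # 1 + num_speculative_tokens
    tokens = np.array([3, 2, 128, 64])    # scheduled tokens per request
    num_decode_reqs = (tokens <= decode_threshold).sum()   # 2
    num_decode_tokens = tokens[:num_decode_reqs].sum()     # 5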
@@ -3118,8 +3128,10 @@ class NPUModelRunner(GPUModelRunner):
     def _generate_pcp_metadata(self, total_num_scheduled_tokens):
+        # In dummy run num_reqs == 0, update it from seq_lens
         num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
-        num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
-                          >= self.input_batch.num_prompt_tokens[:num_reqs])
+        query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
+            if self.pcp_size > 1 and self.speculative_config else self.query_lens
+        num_decodes = (query_lens <= self.decode_threshold).sum().item()
         num_prefills = num_reqs - num_decodes
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
@@ -3137,16 +3149,41 @@ class NPUModelRunner(GPUModelRunner):
                 dtype=torch.int32,
             )
             # For pcp + spec decode, we flatten seq_lens
-            # to avoid irregular spec_attn_mask shape
+            # to avoid irregular spec_attn_mask shape.
+            # Same as block_table, we flatten decode seq_lens to query_lens,
+            # and keep prefill seq_lens unchanged.
             for decode_idx in range(self.decode_threshold):
                 num_computed_tokens_of_pcp_dcp[
                     self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                     self._get_cp_local_seq_lens(
-                        torch.tensor(context_lens),
+                        torch.tensor(context_lens) - decode_idx,
                         self.pcp_size,
                         self.dcp_size,
                         self.parallel_config.cp_kv_cache_interleave_size,
                     )
+            if self.decode_threshold > 1:
+                num_computed_tokens_of_pcp_dcp_list = []
+                if num_decodes:
+                    num_decodes_flatten = \
+                        self.query_lens[:num_decodes].sum().item()
+                    if self.query_lens[:num_decodes].min().item(
+                    ) == self.decode_threshold:
+                        decode_flatten_idx = list(range(num_decodes_flatten))
+                    else:
+                        decode_flatten_idx = []
+                        for req_id in range(num_decodes):
+                            offset = (req_id + 1) * self.decode_threshold
+                            decode_flatten_idx += \
+                                list(range(offset - self.query_lens[req_id], offset))
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
+                if num_prefills:
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[
+                            (num_decodes + 1) * self.decode_threshold -
+                            1::self.decode_threshold])
+                num_computed_tokens_of_pcp_dcp = torch.cat(
+                    num_computed_tokens_of_pcp_dcp_list, dim=0)
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
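Note (illustrative): decode_flatten_idx keeps, for every decode request, only the last query_len slots of its decode_threshold-wide window, which matters when rejected drafts make query lengths ragged. A standalone sketch with toy lengths:

    decode_threshold = 3
    query_lens = [3, 2, 3]          # accepted lengths per decode request
    decode_flatten_idx = []
    for req_id, qlen in enumerate(query_lens):
        offset = (req_id + 1) * decode_threshold
        # keep the last `qlen` slots of this request's window
        decode_flatten_idx += list(range(offset - qlen, offset))
    # -> [0, 1, 2, 4, 5, 6, 7, 8]  (slot 3 of request 1 is dropped)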
@@ -3278,6 +3315,10 @@ class NPUModelRunner(GPUModelRunner):
             num_reqs: int,
             total_num_scheduled_tokens: int,
             num_scheduled_tokens: dict[str, int],
+            with_prefill: bool = True,
+            req_indices=None,
+            positions_np=None,
+            cu_num_tokens=None,
     ):
         """
         While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
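Note (illustrative): the helper rebuilds the full, unsplit per-request layout from the scheduler's global counts. A minimal sketch of that index arithmetic (variable names are illustrative, not the runner's):

    import numpy as np

    num_scheduled_tokens = {"req-a": 4, "req-b": 2}   # global counts
    req_ids = list(num_scheduled_tokens)

    lens = np.array([num_scheduled_tokens[r] for r in req_ids], dtype=np.int32)
    cu = np.cumsum(lens)                              # [4, 6]
    req_indices = np.repeat(np.arange(len(req_ids)), lens)      # [0,0,0,0,1,1]
    positions = np.arange(cu[-1]) - np.repeat(cu - lens, lens)  # [0,1,2,3,0,1]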
@@ -3288,6 +3329,8 @@ class NPUModelRunner(GPUModelRunner):
         num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
+        self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
+            num_scheduled_tokens_pcp_full)
         req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
                                          num_scheduled_tokens_pcp_full)
         cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
@@ -3313,11 +3356,45 @@ class NPUModelRunner(GPUModelRunner):
             torch.from_numpy(token_indices_pcp_full),
             out=self.input_ids_pcp_full.
             cpu[:total_num_scheduled_tokens_pcp_full])
+        self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
         self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
             self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
             non_blocking=True,
         )
         self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
+        # For mtpx, pre-allocate mtp slot_mapping here
+        if self.decode_threshold > 2 and not with_prefill:
+            num_tokens_ori = sum(list(num_scheduled_tokens.values()))
+            num_tokens_mtp = \
+                num_tokens_ori + num_reqs * (self.decode_threshold - 2)
+            num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
+            req_indices_split = np.array_split(req_indices,
+                                               cu_num_tokens)[:num_reqs]
+            positions_split = np.array_split(positions_np,
+                                             cu_num_tokens)[:num_reqs]
+            for req_idx in range(num_reqs):
+                ori_req_indice = req_indices_split[req_idx]
+                ori_position = positions_split[req_idx]
+                req_indices_split[req_idx] = np.append(
+                    ori_req_indice,
+                    np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
+                positions_split[req_idx] = np.append(
+                    ori_position,
+                    np.arange(ori_position[-1] + 1,
+                              ori_position[-1] + self.decode_threshold - 1))
+            req_indices_mtp = np.concatenate(req_indices_split)
+            positions_mtp = np.concatenate(positions_split)
+            self.input_batch.block_table.compute_slot_mapping(
+                req_indices_mtp, positions_mtp)
+            mtp_slot_ori = self.input_batch.block_table.block_tables[
+                0].slot_mapping.cpu[:num_tokens_mtp]
+            unpad_mask = np.repeat(False, num_tokens_mtp_pad)
+            unpad_mask[::self.pcp_size] = True
+            mtp_slot_pad = \
+                torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
+            mtp_slot_pad[unpad_mask] = mtp_slot_ori
+            self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
+

     @contextmanager
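Note (illustrative): the padded slot mapping interleaves real slots at stride pcp_size, so position 0 of every pcp_size-wide group carries a real slot and the rest stay -1 (pad). A toy version of the mask construction:

    import numpy as np
    import torch

    pcp_size = 2
    mtp_slot_ori = torch.tensor([10, 11, 12], dtype=torch.int32)
    num_tokens_mtp_pad = mtp_slot_ori.numel() * pcp_size

    unpad_mask = np.repeat(False, num_tokens_mtp_pad)
    unpad_mask[::pcp_size] = True                       # [T, F, T, F, T, F]
    mtp_slot_pad = torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
    mtp_slot_pad[unpad_mask] = mtp_slot_ori             # [10, -1, 11, -1, 12, -1]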