[feature] support pcp + mtp (in pd co-locate scenario) (#4098)
1. Support PCP + MTP in the PD co-locate scenario
2. LLMDataDist connector: PCP-related bug fixes and code cleanup
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
@@ -487,19 +487,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.speculative_config and self.pcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
-                                                  device="cpu",
-                                                  pin_memory=True)
+                                                  device=self.device)
+            self.input_ids_pcp_full_cpu = torch.zeros(self.max_num_tokens,
+                                                      dtype=torch.int32,
+                                                      device="cpu",
+                                                      pin_memory=True)
             self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
                                                         dtype=torch.int32,
-                                                        device="cpu",
-                                                        pin_memory=True)
-            self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy(
-            )
+                                                        device=self.device)
+            self.query_start_loc_pcp_full_cpu = \
+                torch.zeros(self.max_num_reqs + 1,
+                            dtype=torch.int32,
+                            device="cpu",
+                            pin_memory=True)
+            self.query_start_loc_pcp_full_np = \
+                self.query_start_loc_pcp_full_cpu.numpy()
             self.positions_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int64,
                                                   device="cpu",
                                                   pin_memory=True)
-            self.positions_np_pcp_full = self.positions_pcp_full.numpy()
+            self.positions_pcp_full_np = self.positions_pcp_full.numpy()
+            self.decode_threshold = 1 + (
+                self.speculative_config.num_speculative_tokens
+                if self.speculative_config else 0)

         self.use_aclgraph = self._use_aclgraph()
         self.aclgraph_batch_sizes = list(
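For context on the new decode_threshold field: it counts one target token plus however many MTP draft tokens are handled per decode step. A minimal standalone sketch of that arithmetic, assuming a hypothetical num_speculative_tokens of 2 (the real value comes from self.speculative_config):

# Assumed example value, not taken from the runner configuration.
num_speculative_tokens = 2
decode_threshold = 1 + num_speculative_tokens  # 1 target token + 2 MTP draft tokens
assert decode_threshold == 3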
@@ -1854,8 +1864,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
                 self.device, non_blocking=True)
         else:
-            # pcp not supported now
-            assert self.pcp_size == 1
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
@@ -1866,11 +1874,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_draft_tokens[req_idx] = len(draft_token_ids)

             spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens)
+                num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
             logits_indices = spec_decode_metadata.logits_indices
             self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
             self.num_draft_tokens.np[num_reqs:].fill(0)
             self.num_draft_tokens.copy_to_gpu()
+            # save logits_indices for pcp spec decode usage
+            self.logits_indices = logits_indices

         # Used in the below loop.
         # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1883,8 +1893,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.num_accepted_tokens.np[num_reqs:].fill(1)
         self.num_accepted_tokens.copy_to_gpu()

-        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
-        if self.speculative_config and self.pcp_size > 1 and is_prefill:
+        if self.speculative_config and self.pcp_size > 1:
             self._generate_pcp_mtp_input(
                 num_reqs, scheduler_output.total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens)
@@ -2040,8 +2049,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_dcp_pcp_params(self.update_stream,
                                                forward_context,
-                                               maybe_padded_num_tokens,
-                                               self.speculative_config)
+                                               maybe_padded_num_tokens)
             else:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
@@ -2110,6 +2118,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self,
         num_draft_tokens: np.ndarray,
         cu_num_scheduled_tokens: np.ndarray,
+        num_pcp_pads: np.ndarray,
     ) -> SpecDecodeMetadata:
         # Inputs:
         # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
@@ -2138,6 +2147,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
         logits_indices += arange

+        # while pcp > 1, decode results may contain padding (from pcp all-gather),
+        # update logits_indices after getting draft_token_ids from ori logits_indices
+        if self.pcp_size > 1:
+            cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads
+            logits_indices_pcp = np.repeat(
+                cu_num_scheduled_tokens - num_sampled_tokens,
+                num_sampled_tokens)
+            logits_indices_pcp += arange
+            logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
+                self.device, non_blocking=True)
+
         # Compute the bonus logits indices.
         bonus_logits_indices = cu_num_sampled_tokens - 1

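A minimal numeric sketch of the rescaling added above, assuming pcp_size=2 and a single decode request that schedules 2 tokens (1 target + 1 MTP draft) and picks up 2 pad slots from the all-gather; all concrete numbers are made up for illustration:

import numpy as np

pcp_size = 2
cu_num_scheduled_tokens = np.array([2])   # cumulative scheduled tokens per request
num_pcp_pads = np.array([2])              # pad slots added by the pcp all-gather
num_sampled_tokens = np.array([2])        # sampled positions for this request
arange = np.array([0, 1])                 # per-request offsets of those positions

# Same arithmetic as the diff lines above: scale the cumulative counts to the
# gathered (pcp_size x) token stream, drop the pads, rebuild the indices.
cu_gathered = cu_num_scheduled_tokens * pcp_size - num_pcp_pads           # [2]
logits_indices_pcp = np.repeat(cu_gathered - num_sampled_tokens,
                               num_sampled_tokens) + arange               # [0, 1]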
@@ -2173,6 +2193,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
         draft_token_ids = self.input_ids[logits_indices]
         draft_token_ids = draft_token_ids[target_logits_indices + 1]
+        if self.pcp_size > 1:
+            logits_indices = logits_indices_pcp
         if vllm_version_is("0.11.0"):
             metadata = SpecDecodeMetadata(
                 draft_token_ids=draft_token_ids,
@@ -2920,8 +2942,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_dcp_pcp_params(self.update_stream,
                                                forward_context,
-                                               positions.shape[0],
-                                               self.speculative_config)
+                                               positions.shape[0])
             else:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
@@ -4328,18 +4349,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_decode_reqs = sum(
             self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
             self.input_batch.num_prompt_tokens[:num_reqs])
+        num_decode_tokens = sum(tokens[:num_decode_reqs])
         num_padded_scheduled_tokens = np.ceil(
             tokens /
             (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
-        num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size
+        num_padded_scheduled_tokens[:num_decode_reqs] = (
+            tokens[:num_decode_reqs] * self.pcp_size)
         self.num_pcp_pads = num_padded_scheduled_tokens - tokens
         cu_padded_tokens, pcp_padded_arange = \
             self._get_cumsum_and_arange(num_padded_scheduled_tokens)
         unpad_mask = torch.from_numpy(
             pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
+        unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
+        unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
+        unpad_mask_decode[:, 0] = True
+        unpad_mask_decode[:, 1:] = False

         pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
         pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
+        pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
         _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
         _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
         pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
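A minimal numeric sketch of the padding rule above, assuming pcp_size=2, one decode request scheduling 2 tokens and one prefill request scheduling 5 tokens (values made up for illustration):

import numpy as np

pcp_size = 2
num_decode_reqs = 1
tokens = np.array([2, 5])  # [decode req, prefill req] scheduled tokens

# Prefill requests are padded to a multiple of 2 * pcp_size so every rank gets
# an equal head/tail chunk; decode requests are replicated on every rank.
num_padded = np.ceil(tokens / (2 * pcp_size)).astype(np.int32) * (2 * pcp_size)
num_padded[:num_decode_reqs] = tokens[:num_decode_reqs] * pcp_size
num_pcp_pads = num_padded - tokens  # -> [2, 3]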
@@ -4356,14 +4384,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 np.repeat(head_start_loc, pcp_chunk_sizes)
             # Decode reqs do not have tail chunks.
             positions[~pcp_head_chunk_mask] = \
-                pcp_chunk_arange[num_decode_reqs:] + \
-                np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
+                pcp_chunk_arange[num_decode_tokens:] + \
+                np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
             return positions

         positions = get_current_rank_positions(
             np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
         # Decode tokens are duplicate and their positions always be 0.
-        positions[:num_decode_reqs] = 0
+        if num_decode_reqs > 0:
+            positions[:num_decode_tokens] = self._get_cumsum_and_arange(
+                tokens[:num_decode_reqs])[1]

         all_positions = [
             get_current_rank_positions(cu_padded_tokens, rank_i)
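Since each decode request now carries one target plus its draft tokens, its local positions become 0..k-1 rather than all zeros. A small sketch of that per-request arange, assuming _get_cumsum_and_arange follows the usual cumsum/offset pattern (an assumption, not taken from this diff):

import numpy as np

# Assumed example: two decode requests, each scheduling 1 target + 2 draft tokens.
tokens = np.array([3, 3])
cu = np.cumsum(tokens)                                        # [3, 6]
arange = np.arange(cu[-1]) - np.repeat(cu - tokens, tokens)   # [0, 1, 2, 0, 1, 2]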
@@ -4372,7 +4402,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
         self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
             all_positions_tensor.float().argsort().long(), non_blocking=True)
-        pcp_tokens[:num_decode_reqs] = 1
         return pcp_tokens, positions, unpad_mask

     def _get_pcp_local_seq_lens(
@@ -4524,7 +4553,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):

     def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
                                seq_lens_origin):
-        num_reqs = self.input_batch.num_reqs
+        # In dummy run num_reqs == 0, update it from seq_lens
+        num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
@@ -4535,14 +4565,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                            local_chunked_kv_lens)
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
+            num_computed_tokens_of_pcp_dcp = torch.zeros(
+                [
+                    num_reqs * self.decode_threshold, self.pcp_size,
+                    self.dcp_size
+                ],
+                dtype=torch.int32,
+            )
+            # For pcp + spec decode, we flatten seq_lens
+            # to avoid irregular spec_attn_mask shape
+            for decode_idx in range(self.decode_threshold):
+                num_computed_tokens_of_pcp_dcp[
+                    self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
+                    self._get_pcp_local_seq_lens(
+                        seq_lens_origin - decode_idx,
+                        self.pcp_size,
+                        self.dcp_size,
+                        self.parallel_config.cp_kv_cache_interleave_size,
+                    )
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
-                num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
-                    seq_lens_origin,
-                    self.pcp_size,
-                    self.dcp_size,
-                    self.parallel_config.cp_kv_cache_interleave_size,
-                ).numpy(),
+                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
+                numpy(),
                 local_chunked_kv_lens=local_chunked_kv_lens,
                 mask_for_non_zero_chunk=mask_for_non_zero_chunk,
                 max_chunk_num=max_chunk_num)
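A minimal sketch of the flattening loop above, with assumed values (decode_threshold=2, two requests) and the pcp/dcp dimensions dropped for brevity; _get_pcp_local_seq_lens is replaced by the raw sequence lengths, an assumption made only for illustration:

import torch

decode_threshold = 2
seq_lens_origin = torch.tensor([10, 7], dtype=torch.int32)
num_reqs = seq_lens_origin.numel()

flat = torch.zeros(num_reqs * decode_threshold, dtype=torch.int32)
for decode_idx in range(decode_threshold):
    # Same strided write as the diff: row (decode_threshold - 1 - decode_idx) of
    # every request holds the sequence length shortened by decode_idx tokens.
    flat[decode_threshold - 1 - decode_idx::decode_threshold] = \
        seq_lens_origin - decode_idx
# flat == [9, 10, 6, 7]: each request contributes decode_threshold consecutive rows.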
@@ -4706,16 +4750,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_scheduled_tokens_pcp_full)
         arange_pcp_full = self.arange_np[:
                                          total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
-        positions_np_pcp_full = self.positions_np_pcp_full[:
+        positions_pcp_full_np = self.positions_pcp_full_np[:
                                                            total_num_scheduled_tokens_pcp_full]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
                arange_pcp_full,
-               out=positions_np_pcp_full)
+               out=positions_pcp_full_np)
         token_indices_pcp_full = (
-            positions_np_pcp_full +
+            positions_pcp_full_np +
             req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
         torch.index_select(
             self.input_batch.token_ids_cpu_tensor.flatten(),
             0,
             torch.from_numpy(token_indices_pcp_full),
-            out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
+            out=self.
+            input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full])
+        self.query_start_loc_pcp_full[:num_reqs + 1].copy_(
+            self.query_start_loc_pcp_full_cpu[:num_reqs + 1],
+            non_blocking=True,
+        )
+        self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full].copy_(
+            self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
+            non_blocking=True,
+        )
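The _generate_pcp_mtp_input changes above follow a common staging pattern: assemble the batch in a pinned CPU buffer, then copy_(..., non_blocking=True) into a preallocated device buffer. A self-contained sketch of that pattern with made-up sizes (the device string and sizes are assumptions, not taken from the runner):

import torch

max_num_tokens, n = 1024, 7
device = "cpu"  # stand-in; the runner targets its NPU device here

use_pinned = device != "cpu"  # pinned host memory needs an accelerator backend
input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
input_ids_cpu = torch.zeros(max_num_tokens, dtype=torch.int32,
                            device="cpu", pin_memory=use_pinned)

# Stage the token ids on the host buffer, then issue a (potentially async)
# copy into the preallocated device buffer.
input_ids_cpu[:n] = torch.arange(n, dtype=torch.int32)
input_ids[:n].copy_(input_ids_cpu[:n], non_blocking=True)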