[bugfix](CP) Fix and unify the PD request discrimination logic. (#5939)
### What this PR does / why we need it?
Since the PR (https://github.com/vllm-project/vllm/pull/32118) has
modified the criteria for judging Prefill and Decode requests in vLLM,
PCPManager needs to synchronize with this standard. As PCPManager
involves multiple calculations of PD request counts, this PR attempts to
consolidate the related logic and update the PD request count once per
batch.
### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
```
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -20,17 +20,18 @@
|
|||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
|
||||||
|
|
||||||
os.environ["HCCL_BUFFSIZE"] = "512"
|
os.environ["HCCL_BUFFSIZE"] = "512"
|
||||||
|
|
||||||
|
prompts = [
|
||||||
def test_pcp_dcp_mtp1_eager():
|
|
||||||
prompts = [
|
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
"The capital of France is", "Hello, my name is Tom, I am",
|
||||||
"The president of United States is", "AI future is"
|
"The president of United States is", "AI future is"
|
||||||
]
|
]
|
||||||
model = "wemaster/deepseek_mtp_main_random_bf16"
|
model = "wemaster/deepseek_mtp_main_random_bf16"
|
||||||
|
|
||||||
|
@wait_until_npu_memory_free()
|
||||||
|
def test_pcp_dcp_mtp1_eager():
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@@ -50,15 +51,8 @@ def test_pcp_dcp_mtp1_eager():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@wait_until_npu_memory_free()
|
||||||
reason="vLLM PR-32118 break this",
|
|
||||||
)
|
|
||||||
def test_pcp_dcp_mtp3_eager():
|
def test_pcp_dcp_mtp3_eager():
|
||||||
prompts = [
|
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
|
||||||
"The president of United States is", "AI future is"
|
|
||||||
]
|
|
||||||
model = "wemaster/deepseek_mtp_main_random_bf16"
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@@ -78,15 +72,8 @@ def test_pcp_dcp_mtp3_eager():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@wait_until_npu_memory_free()
|
||||||
reason="vLLM PR-32118 break this",
|
|
||||||
)
|
|
||||||
def test_pcp_dcp_mtp3_piecewise_graph():
|
def test_pcp_dcp_mtp3_piecewise_graph():
|
||||||
prompts = [
|
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
|
||||||
"The president of United States is", "AI future is"
|
|
||||||
]
|
|
||||||
model = "wemaster/deepseek_mtp_main_random_bf16"
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@@ -109,15 +96,8 @@ def test_pcp_dcp_mtp3_piecewise_graph():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@wait_until_npu_memory_free()
|
||||||
reason="vLLM PR-32118 break this",
|
|
||||||
)
|
|
||||||
def test_pcp_dcp_mtp3_full_graph():
|
def test_pcp_dcp_mtp3_full_graph():
|
||||||
prompts = [
|
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
|
||||||
"The president of United States is", "AI future is"
|
|
||||||
]
|
|
||||||
model = "wemaster/deepseek_mtp_main_random_bf16"
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
@@ -140,12 +120,8 @@ def test_pcp_dcp_mtp3_full_graph():
|
|||||||
runner.generate_greedy(prompts, 32)
|
runner.generate_greedy(prompts, 32)
|
||||||
|
|
||||||
|
|
||||||
|
@wait_until_npu_memory_free()
|
||||||
def test_dcp_mtp3_full_graph():
|
def test_dcp_mtp3_full_graph():
|
||||||
prompts = [
|
|
||||||
"The capital of France is", "Hello, my name is Tom, I am",
|
|
||||||
"The president of United States is", "AI future is"
|
|
||||||
]
|
|
||||||
model = "wemaster/deepseek_mtp_main_random_bf16"
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
|
|||||||
@@ -141,8 +141,10 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
|
|||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
|
input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
|
||||||
arange_np = np.arange(10000)
|
arange_np = np.arange(10000)
|
||||||
|
num_scheduled_tokens = np.array(tokens)
|
||||||
|
pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
|
||||||
pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
|
pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
|
||||||
np.array(tokens), arange_np, num_reqs, 1)
|
num_scheduled_tokens, arange_np)
|
||||||
|
|
||||||
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
|
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
|
||||||
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
|
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
|
||||||
@@ -305,8 +307,8 @@ def test_generate_pcp_mtp_input(
|
|||||||
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
|
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
|
||||||
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
|
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
|
||||||
|
|
||||||
pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
|
pcp_manager.init_batch_info(np.array(list(num_scheduled_tokens.values())), num_reqs)
|
||||||
num_scheduled_tokens, False,
|
pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens, num_scheduled_tokens, False,
|
||||||
input_batch, arange_np)
|
input_batch, arange_np)
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
|
pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
|
||||||
|
|||||||
@@ -600,10 +600,16 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
req_indices, positions_np)
|
req_indices, positions_np)
|
||||||
self.input_batch.block_table.commit_slot_mapping(
|
self.input_batch.block_table.commit_slot_mapping(
|
||||||
total_num_scheduled_tokens)
|
total_num_scheduled_tokens)
|
||||||
|
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
self.pcp_manager.init_batch_info(
|
||||||
|
num_scheduled_tokens,
|
||||||
|
self.input_batch.num_reqs,
|
||||||
|
)
|
||||||
|
|
||||||
# for pcp, prefill mtp should use origin scheduleroutput ,
|
# for pcp, prefill mtp should use origin scheduleroutput ,
|
||||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||||
self.pcp_manager.generate_pcp_mtp_input(
|
self.pcp_manager.generate_pcp_mtp_input(
|
||||||
num_reqs,
|
|
||||||
total_num_scheduled_tokens,
|
total_num_scheduled_tokens,
|
||||||
scheduler_output.num_scheduled_tokens,
|
scheduler_output.num_scheduled_tokens,
|
||||||
with_prefill,
|
with_prefill,
|
||||||
@@ -621,8 +627,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
||||||
num_scheduled_tokens[:num_reqs],
|
num_scheduled_tokens[:num_reqs],
|
||||||
self.arange_np,
|
self.arange_np,
|
||||||
self.input_batch.num_reqs,
|
|
||||||
self.reorder_batch_threshold,
|
|
||||||
)
|
)
|
||||||
# Re-update after PCP split sequences.
|
# Re-update after PCP split sequences.
|
||||||
total_num_scheduled_tokens = sum(num_scheduled_tokens)
|
total_num_scheduled_tokens = sum(num_scheduled_tokens)
|
||||||
@@ -772,8 +776,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_draft_tokens = None
|
num_draft_tokens = None
|
||||||
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
logits_indices = self.pcp_manager.get_logits_indices(
|
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
|
||||||
cu_num_tokens, num_reqs)
|
|
||||||
logits_indices = logits_indices.pin_memory().to(
|
logits_indices = logits_indices.pin_memory().to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
@@ -1020,9 +1023,8 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||||
num_prefill_reqs = (ori_query_lens
|
num_prefill_reqs = self.pcp_manager.num_prefill_reqs
|
||||||
> self.decode_threshold).sum().item()
|
num_decode_reqs = self.pcp_manager.num_decode_reqs
|
||||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
|
||||||
else:
|
else:
|
||||||
long_seq_metadata = None # type: ignore
|
long_seq_metadata = None # type: ignore
|
||||||
num_prefill_reqs = 0
|
num_prefill_reqs = 0
|
||||||
@@ -1976,7 +1978,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
)
|
)
|
||||||
return blk_table_tensor, slot_mapping
|
return blk_table_tensor, slot_mapping
|
||||||
|
|
||||||
long_seq_metdadata = _get_pcp_metadata(num_tokens)
|
self.long_seq_metadata = _get_pcp_metadata(num_tokens)
|
||||||
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
|
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
|
||||||
|
|
||||||
actual_last_loc = self.query_start_loc.np[num_reqs_padded]
|
actual_last_loc = self.query_start_loc.np[num_reqs_padded]
|
||||||
@@ -2008,7 +2010,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
positions=self.positions.gpu,
|
positions=self.positions.gpu,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
prefill_context_parallel_metadata=long_seq_metdadata,
|
prefill_context_parallel_metadata=self.long_seq_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
|
if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
|
||||||
@@ -2198,6 +2200,11 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
force_has_lora=activate_lora,
|
force_has_lora=activate_lora,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
self.pcp_manager.init_batch_info(
|
||||||
|
num_scheduled_tokens,
|
||||||
|
num_reqs,
|
||||||
|
)
|
||||||
if cudagraph_runtime_mode is None:
|
if cudagraph_runtime_mode is None:
|
||||||
cudagraph_runtime_mode = _cudagraph_mode
|
cudagraph_runtime_mode = _cudagraph_mode
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ class PCPManager:
|
|||||||
This manager encapsulates all PCP-related buffers and logic so that the
|
This manager encapsulates all PCP-related buffers and logic so that the
|
||||||
ModelRunner can access them via `self.pcp_manager`.
|
ModelRunner can access them via `self.pcp_manager`.
|
||||||
"""
|
"""
|
||||||
|
num_reqs: int = 0
|
||||||
|
num_decode_reqs: int = 0
|
||||||
|
num_prefill_reqs: int = 0
|
||||||
|
num_decode_tokens: int = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -133,12 +137,25 @@ class PCPManager:
|
|||||||
|
|
||||||
return cu_num_tokens, arange
|
return cu_num_tokens, arange
|
||||||
|
|
||||||
|
def init_batch_info(
|
||||||
|
self,
|
||||||
|
num_scheduled_tokens: np.ndarray,
|
||||||
|
num_reqs: int,
|
||||||
|
) -> None:
|
||||||
|
self.num_reqs = num_reqs
|
||||||
|
is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold)
|
||||||
|
if not any(is_prefill):
|
||||||
|
first_prefill = num_reqs
|
||||||
|
else:
|
||||||
|
first_prefill = is_prefill.argmax()
|
||||||
|
self.num_decode_reqs = first_prefill
|
||||||
|
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
|
||||||
|
self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum()
|
||||||
|
|
||||||
def update_tokens_for_pcp(
|
def update_tokens_for_pcp(
|
||||||
self,
|
self,
|
||||||
num_scheduled_tokens: np.ndarray,
|
num_scheduled_tokens: np.ndarray,
|
||||||
arange_np: np.ndarray,
|
arange_np: np.ndarray,
|
||||||
num_reqs: int,
|
|
||||||
reorder_batch_threshold: int | None = None,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Update token counts and positions for Prefill Context Parallelism (PCP).
|
Update token counts and positions for Prefill Context Parallelism (PCP).
|
||||||
@@ -167,8 +184,6 @@ class PCPManager:
|
|||||||
the number of new tokens scheduled per request.
|
the number of new tokens scheduled per request.
|
||||||
arange_np: 1D numpy array of length max_buffer_num_tokens used for
|
arange_np: 1D numpy array of length max_buffer_num_tokens used for
|
||||||
efficient batched arange operations.
|
efficient batched arange operations.
|
||||||
num_reqs: Total number of requests in the batch.
|
|
||||||
reorder_batch_threshold: Threshold for decode vs prefill requests.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple (pcp_tokens, pcp_positions):
|
Tuple (pcp_tokens, pcp_positions):
|
||||||
@@ -187,16 +202,10 @@ class PCPManager:
|
|||||||
>>> self.pcp_unpad_mask_cpu
|
>>> self.pcp_unpad_mask_cpu
|
||||||
[True, False, True, True, True, True, True, False, False,
|
[True, False, True, True, True, True, True, False, False,
|
||||||
False, True, True, True, True, True, True, True, True]
|
False, True, True, True, True, True, True, True, True]
|
||||||
>>> self.pcp_allgather_resotre_idx
|
>>> self.pcp_allgather_restore_idx
|
||||||
[0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
|
[0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert reorder_batch_threshold is not None, (
|
|
||||||
"PCP depends on reorder batch to split decode and prefill requests."
|
|
||||||
)
|
|
||||||
num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
|
|
||||||
num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])
|
|
||||||
|
|
||||||
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
|
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
|
||||||
# We first pad each request's token count up to that multiple.
|
# We first pad each request's token count up to that multiple.
|
||||||
num_padded_scheduled_tokens = np.ceil(
|
num_padded_scheduled_tokens = np.ceil(
|
||||||
@@ -205,11 +214,11 @@ class PCPManager:
|
|||||||
|
|
||||||
# PCP does not split decode requests. For decode requests, we instead
|
# PCP does not split decode requests. For decode requests, we instead
|
||||||
# duplicate the scheduled tokens across the pcp_world_size ranks.
|
# duplicate the scheduled tokens across the pcp_world_size ranks.
|
||||||
num_padded_scheduled_tokens[:num_decode_reqs] = (
|
num_padded_scheduled_tokens[:self.num_decode_reqs] = (
|
||||||
num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size)
|
num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size)
|
||||||
|
|
||||||
# Record how many pads were added per request (padded - original).
|
# Record how many pads were added per request (padded - original).
|
||||||
self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens -
|
self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens -
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
|
|
||||||
# cu_padded_tokens: cumulative sum of padded token counts,
|
# cu_padded_tokens: cumulative sum of padded token counts,
|
||||||
@@ -221,7 +230,7 @@ class PCPManager:
|
|||||||
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
|
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
|
||||||
pcp_padded_arange < np.repeat(num_scheduled_tokens,
|
pcp_padded_arange < np.repeat(num_scheduled_tokens,
|
||||||
num_padded_scheduled_tokens))
|
num_padded_scheduled_tokens))
|
||||||
unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens *
|
unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens *
|
||||||
self.pcp_world_size]
|
self.pcp_world_size]
|
||||||
unpad_mask_decode = unpad_mask_decode.reshape(
|
unpad_mask_decode = unpad_mask_decode.reshape(
|
||||||
[-1, self.pcp_world_size])
|
[-1, self.pcp_world_size])
|
||||||
@@ -233,7 +242,7 @@ class PCPManager:
|
|||||||
# For prefill requests, we further split the pcp_tokens into two chunks
|
# For prefill requests, we further split the pcp_tokens into two chunks
|
||||||
# (head and tail). For decode requests, the chunk equals pcp_tokens.
|
# (head and tail). For decode requests, the chunk equals pcp_tokens.
|
||||||
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
|
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
|
||||||
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
|
pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs]
|
||||||
|
|
||||||
# Build arange-style helpers for pcp tokens and chunk sizes:
|
# Build arange-style helpers for pcp tokens and chunk sizes:
|
||||||
# - pcp_arange gives indices repeated for each token in pcp_tokens
|
# - pcp_arange gives indices repeated for each token in pcp_tokens
|
||||||
@@ -271,16 +280,16 @@ class PCPManager:
|
|||||||
# Fill tail positions. Note decode requests do not have tail chunks,
|
# Fill tail positions. Note decode requests do not have tail chunks,
|
||||||
# so the tail filling is only for prefill positions.
|
# so the tail filling is only for prefill positions.
|
||||||
positions[~pcp_head_chunk_mask] = (
|
positions[~pcp_head_chunk_mask] = (
|
||||||
pcp_chunk_arange[num_decode_tokens:] +
|
pcp_chunk_arange[self.num_decode_tokens:] +
|
||||||
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:])
|
np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:])
|
||||||
return positions
|
return positions
|
||||||
|
|
||||||
positions = get_current_rank_positions(0, self.pcp_world_rank)
|
positions = get_current_rank_positions(0, self.pcp_world_rank)
|
||||||
# Decode tokens are duplicated only after AG. But their positions are
|
# Decode tokens are duplicated only after AG. But their positions are
|
||||||
# same without prefill context parallel.
|
# same without prefill context parallel.
|
||||||
if num_decode_reqs > 0:
|
if self.num_decode_reqs > 0:
|
||||||
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
|
positions[:self.num_decode_tokens] = self._get_cumsum_and_arange(
|
||||||
num_scheduled_tokens[:num_decode_reqs], arange_np)[1]
|
num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1]
|
||||||
|
|
||||||
# Build the restore index used after allgather.
|
# Build the restore index used after allgather.
|
||||||
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
|
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
|
||||||
@@ -294,27 +303,16 @@ class PCPManager:
|
|||||||
all_positions.argsort())
|
all_positions.argsort())
|
||||||
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
|
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
|
||||||
|
|
||||||
self.pcp_tokens[:num_reqs] = pcp_tokens[:num_reqs]
|
self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs]
|
||||||
self.total_num_sampled_tokens_pcp = pcp_tokens[:num_reqs].sum()
|
self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum()
|
||||||
return (
|
return (
|
||||||
pcp_tokens[:num_reqs],
|
pcp_tokens[:self.num_reqs],
|
||||||
positions,
|
positions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
|
def get_logits_indices(self, cu_num_tokens: np.ndarray):
|
||||||
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
|
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
|
||||||
self.num_pcp_pads_cpu_tensor[:num_reqs] - 1)
|
self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
|
||||||
|
|
||||||
def get_discard_request_mask(
|
|
||||||
self,
|
|
||||||
num_computed_tokens_cpu: np.ndarray,
|
|
||||||
num_scheduled_tokens: np.ndarray,
|
|
||||||
num_reqs: int,
|
|
||||||
num_tokens_np: np.ndarray,
|
|
||||||
):
|
|
||||||
return (num_computed_tokens_cpu[:num_reqs] +
|
|
||||||
num_scheduled_tokens * self.pcp_world_size -
|
|
||||||
self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np
|
|
||||||
|
|
||||||
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
|
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
|
||||||
slot_mapping: torch.Tensor):
|
slot_mapping: torch.Tensor):
|
||||||
@@ -350,7 +348,6 @@ class PCPManager:
|
|||||||
|
|
||||||
def generate_pcp_mtp_input(
|
def generate_pcp_mtp_input(
|
||||||
self,
|
self,
|
||||||
num_reqs: int,
|
|
||||||
total_num_scheduled_tokens: int,
|
total_num_scheduled_tokens: int,
|
||||||
num_scheduled_tokens: dict[str, int],
|
num_scheduled_tokens: dict[str, int],
|
||||||
with_prefill: bool = True,
|
with_prefill: bool = True,
|
||||||
@@ -369,18 +366,18 @@ class PCPManager:
|
|||||||
so we record original input_ids here.
|
so we record original input_ids here.
|
||||||
"""
|
"""
|
||||||
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
|
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
|
||||||
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
|
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
|
||||||
for i, req_id in enumerate(input_batch.req_ids):
|
for i, req_id in enumerate(input_batch.req_ids):
|
||||||
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
|
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
|
||||||
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
|
self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy(
|
||||||
num_scheduled_tokens_pcp_full)
|
num_scheduled_tokens_pcp_full)
|
||||||
req_indices_pcp_full = np.repeat(arange_np[:num_reqs],
|
req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
|
||||||
num_scheduled_tokens_pcp_full)
|
num_scheduled_tokens_pcp_full)
|
||||||
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
|
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
|
||||||
self.query_start_loc_pcp_full.np[0] = 0
|
self.query_start_loc_pcp_full.np[0] = 0
|
||||||
self.query_start_loc_pcp_full.np[1:num_reqs +
|
self.query_start_loc_pcp_full.np[1:self.num_reqs +
|
||||||
1] = cu_num_tokens_pcp_full
|
1] = cu_num_tokens_pcp_full
|
||||||
self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
|
self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
|
||||||
cumsums_offsets_pcp_full = np.repeat(
|
cumsums_offsets_pcp_full = np.repeat(
|
||||||
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
|
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
|
||||||
num_scheduled_tokens_pcp_full)
|
num_scheduled_tokens_pcp_full)
|
||||||
@@ -413,13 +410,13 @@ class PCPManager:
|
|||||||
if self.decode_threshold > 2 and not with_prefill:
|
if self.decode_threshold > 2 and not with_prefill:
|
||||||
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
|
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
|
||||||
num_tokens_mtp = \
|
num_tokens_mtp = \
|
||||||
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
|
num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
|
||||||
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
|
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
|
||||||
req_indices_split = np.array_split(req_indices,
|
req_indices_split = np.array_split(req_indices,
|
||||||
cu_num_tokens)[:num_reqs]
|
cu_num_tokens)[:self.num_reqs]
|
||||||
positions_split = np.array_split(positions_np,
|
positions_split = np.array_split(positions_np,
|
||||||
cu_num_tokens)[:num_reqs]
|
cu_num_tokens)[:self.num_reqs]
|
||||||
for req_idx in range(num_reqs):
|
for req_idx in range(self.num_reqs):
|
||||||
ori_req_indice = req_indices_split[req_idx]
|
ori_req_indice = req_indices_split[req_idx]
|
||||||
ori_position = positions_split[req_idx]
|
ori_position = positions_split[req_idx]
|
||||||
req_indices_split[req_idx] = np.append(
|
req_indices_split[req_idx] = np.append(
|
||||||
@@ -567,25 +564,20 @@ class PCPManager:
|
|||||||
input_batch, num_scheduled_tokens):
|
input_batch, num_scheduled_tokens):
|
||||||
from vllm_ascend.attention.utils import \
|
from vllm_ascend.attention.utils import \
|
||||||
AscendPrefillContextParallelMetadata
|
AscendPrefillContextParallelMetadata
|
||||||
num_reqs = input_batch.num_reqs or query_lens.size(0)
|
|
||||||
query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \
|
|
||||||
if self.pcp_world_size > 1 and self.speculative_config else query_lens
|
|
||||||
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
|
|
||||||
num_prefills = num_reqs - num_decodes
|
|
||||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
|
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
|
||||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||||
long_seq_metadata = None
|
long_seq_metadata = None
|
||||||
if self.pcp_world_size * self.dcp_world_size > 1:
|
if self.pcp_world_size * self.dcp_world_size > 1:
|
||||||
decode_context_lens = input_batch.num_computed_tokens_cpu[:
|
decode_context_lens = input_batch.num_computed_tokens_cpu[:
|
||||||
num_decodes] + num_scheduled_tokens[:
|
self.num_decode_reqs] + num_scheduled_tokens[:
|
||||||
num_decodes]
|
self.num_decode_reqs]
|
||||||
prefill_context_lens = input_batch.num_computed_tokens_cpu[
|
prefill_context_lens = input_batch.num_computed_tokens_cpu[
|
||||||
num_decodes:num_reqs]
|
self.num_decode_reqs:self.num_reqs]
|
||||||
context_lens = np.concatenate(
|
context_lens = np.concatenate(
|
||||||
[decode_context_lens, prefill_context_lens])
|
[decode_context_lens, prefill_context_lens])
|
||||||
num_computed_tokens_of_pcp_dcp = torch.zeros(
|
num_computed_tokens_of_pcp_dcp = torch.zeros(
|
||||||
[
|
[
|
||||||
num_reqs * self.decode_threshold, self.pcp_world_size,
|
self.num_reqs * self.decode_threshold, self.pcp_world_size,
|
||||||
self.dcp_world_size
|
self.dcp_world_size
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@@ -605,24 +597,24 @@ class PCPManager:
|
|||||||
)
|
)
|
||||||
if self.decode_threshold > 1:
|
if self.decode_threshold > 1:
|
||||||
num_computed_tokens_of_pcp_dcp_list = []
|
num_computed_tokens_of_pcp_dcp_list = []
|
||||||
if num_decodes:
|
if self.num_decode_reqs:
|
||||||
num_decodes_flatten = \
|
num_decodes_flatten = \
|
||||||
query_lens[:num_decodes].sum().item()
|
query_lens[:self.num_decode_reqs].sum().item()
|
||||||
if query_lens[:num_decodes].min().item(
|
if query_lens[:self.num_decode_reqs].min().item(
|
||||||
) == self.decode_threshold:
|
) == self.decode_threshold:
|
||||||
decode_flatten_idx = list(range(num_decodes_flatten))
|
decode_flatten_idx = list(range(num_decodes_flatten))
|
||||||
else:
|
else:
|
||||||
decode_flatten_idx = []
|
decode_flatten_idx = []
|
||||||
for req_id in range(num_decodes):
|
for req_id in range(self.num_decode_reqs):
|
||||||
offset = (req_id + 1) * self.decode_threshold
|
offset = (req_id + 1) * self.decode_threshold
|
||||||
decode_flatten_idx += \
|
decode_flatten_idx += \
|
||||||
list(range(offset - query_lens[req_id], offset))
|
list(range(offset - query_lens[req_id], offset))
|
||||||
num_computed_tokens_of_pcp_dcp_list.append(
|
num_computed_tokens_of_pcp_dcp_list.append(
|
||||||
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
|
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
|
||||||
if num_prefills:
|
if self.num_prefill_reqs:
|
||||||
num_computed_tokens_of_pcp_dcp_list.append(
|
num_computed_tokens_of_pcp_dcp_list.append(
|
||||||
num_computed_tokens_of_pcp_dcp[
|
num_computed_tokens_of_pcp_dcp[
|
||||||
(num_decodes + 1) * self.decode_threshold -
|
(self.num_decode_reqs + 1) * self.decode_threshold -
|
||||||
1::self.decode_threshold])
|
1::self.decode_threshold])
|
||||||
num_computed_tokens_of_pcp_dcp = torch.cat(
|
num_computed_tokens_of_pcp_dcp = torch.cat(
|
||||||
num_computed_tokens_of_pcp_dcp_list, dim=0)
|
num_computed_tokens_of_pcp_dcp_list, dim=0)
|
||||||
@@ -643,7 +635,7 @@ class PCPManager:
|
|||||||
q_head_chunk_id = self.pcp_world_rank
|
q_head_chunk_id = self.pcp_world_rank
|
||||||
q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
|
q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
|
||||||
for i, seq_len in enumerate(query_lens):
|
for i, seq_len in enumerate(query_lens):
|
||||||
if i < num_decodes:
|
if i < self.num_decode_reqs:
|
||||||
continue
|
continue
|
||||||
chunk_len = seq_len // 2
|
chunk_len = seq_len // 2
|
||||||
chunk_seqlens.append(chunk_len)
|
chunk_seqlens.append(chunk_len)
|
||||||
|
|||||||
Reference in New Issue
Block a user