[Feature] Refactor PCP & DCP related code (#5214)

### What this PR does / why we need it?
Refactor the PCP & DCP related code. A new `PCPManager` class now manages PCP & DCP
in one place. With this change, a lot of code can be removed from `model_runner`,
and other development work is much less likely to accidentally break PCP & DCP.
A minimal usage sketch is shown below.
RFC: https://github.com/vllm-project/vllm-ascend/issues/5449
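
For reference, the new `vllm_ascend.worker.pcp_utils.PCPManager` owns the PCP/DCP buffers and helpers that previously lived on `NPUModelRunner`. The snippet below is a minimal sketch that mirrors the new unit tests added in this PR; the config is a `MagicMock` and the token counts are illustrative, so it shows the intended call pattern rather than production usage:

```python
# Minimal sketch (mirrors the new unit tests in this PR): the vllm config is
# mocked and the token counts are illustrative, not real scheduler output.
from unittest.mock import MagicMock

import numpy as np
import torch

from vllm_ascend.worker.pcp_utils import PCPManager

vllm_config = MagicMock()
vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
vllm_config.speculative_config.num_speculative_tokens = 0

# All PCP/DCP buffers and helpers that used to live on NPUModelRunner are
# created and owned by this single object.
pcp_manager = PCPManager(pcp_world_size=2, pcp_rank=0,
                         dcp_world_size=1, dcp_rank=0,
                         max_buffer_num_tokens=10000, max_num_reqs=1000,
                         device="cpu", vllm_config=vllm_config,
                         pin_memory=False)

# 1) Split each request's scheduled tokens across the PCP ranks and get the
#    rank-local positions (the runner previously did this inline).
tokens = np.array([40, 50, 60])
pcp_tokens, positions = pcp_manager.update_tokens_for_pcp(
    tokens, np.arange(10000), 3, 1)

# 2) Build the prefill-context-parallel ("long sequence") metadata that the
#    attention backends consume.
input_batch = MagicMock()
input_batch.num_reqs = 3
input_batch.num_computed_tokens_cpu = torch.tensor([0, 0, 0])  # all prefills
input_batch.num_prompt_tokens = torch.tensor([40, 50, 60])
input_batch.num_tokens = torch.tensor([40, 50, 60])
long_seq_metadata = pcp_manager.generate_pcp_metadata(
    150, torch.tensor([40, 50, 60]), None, input_batch)
```

In `NPUModelRunner._prepare_inputs`, the same two calls replace the old inline `_update_tokens_for_pcp` / `_generate_pcp_metadata` logic, as the diff below shows.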
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
zhenwenqi2024, 2025-12-31 09:29:57 +08:00, committed by GitHub
parent 46862ce1af
commit 5d9fde9819
7 changed files with 1156 additions and 1047 deletions


@@ -225,21 +225,22 @@ class TestMtpProposer:
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
mock_deps.runner.pcp_size = 2
mock_deps.runner.dcp_size = 1
mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
mock_deps.runner.pcp_manager = MagicMock()
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
32,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.input_ids_pcp_full.cpu = \
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
torch.arange(32, dtype=torch.int32)
mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
5,
dtype=torch.int32,
pin_memory=False,
device='cpu',
)
mock_deps.runner.query_start_loc_pcp_full.cpu = \
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
torch.tensor([0, 8, 16, 24, 32])
mock_deps.positions = torch.arange(16, dtype=torch.int32)
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)


@@ -1,473 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@pytest.mark.parametrize(
"pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
[
(1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
(1, 2, 3, [20, 30, 40], 1, False, 50, True),
(2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
(2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
(2, 1, 3, [5, 10, 15], 3, False, 50, True),
(2, 1, 3, [40, 50, 60], 0, False, 150, True),
])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
num_decodes, use_mla, total_tokens,
expect_not_none):
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.pcp_size = pcp_size
mock_runner.dcp_size = dcp_size
mock_runner.decode_threshold = 4
mock_runner.pcp_rank = 0
mock_runner.device = torch.device('cpu')
mock_runner.dtype = torch.float32
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.cp_kv_cache_interleave_size = 64
mock_runner.vllm_config = MagicMock()
mock_runner.vllm_config.model_config = MagicMock()
mock_runner.vllm_config.model_config.use_mla = use_mla
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_reqs = num_reqs
mock_runner.speculative_config = None
num_computed_tokens = []
num_prompt_tokens = []
num_tokens = []
for i in range(num_reqs):
if i < num_decodes:
num_computed_tokens.append(query_lens[i])
num_prompt_tokens.append(query_lens[i] // 2)
num_tokens.append(query_lens[i])
else:
num_computed_tokens.append(0)
num_prompt_tokens.append(query_lens[i])
num_tokens.append(query_lens[i])
mock_runner.input_batch.num_computed_tokens_cpu = torch.tensor(
num_computed_tokens)
mock_runner.input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
mock_runner.input_batch.num_tokens = torch.tensor(num_tokens)
mock_runner.query_lens = torch.tensor(query_lens)
mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__(
mock_runner, NPUModelRunner)
mock_runner.pcp_allgather_restore_idx = torch.arange(total_tokens * 2)
mock_runner.cp_kv_recover_idx_for_chunk = torch.arange(total_tokens)
mock_runner.long_seq_metadata = None
mock_runner.num_actual_tokens_pcp_padded = 0
mock_runner.kv_idx_names = {}
mock_runner.extra_long_seq_kwargs = {}
mock_runner.attn_mask = None
mock_runner.q_head_idx_tensor = None
mock_runner.q_tail_idx_tensor = None
mock_runner.q_full_idx = None
method = NPUModelRunner._generate_pcp_metadata.__get__(
mock_runner, NPUModelRunner)
result = method(total_tokens)
if not expect_not_none:
assert result is None, f"Expected to return None, but got {type(result)}"
else:
assert result is not None, "Expected to return a metadata object, but got None."
assert hasattr(result, 'num_actual_tokens_pcp_padded')
assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')
if pcp_size > 1:
assert hasattr(result, 'pcp_allgather_restore_idx')
has_prefill_requests = (num_reqs - num_decodes) > 0
if has_prefill_requests:
assert hasattr(result, 'q_head_idx_tensor')
assert hasattr(result, 'q_tail_idx_tensor')
assert hasattr(result, 'q_full_idx')
assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
assert hasattr(result, 'attn_mask_seqlens')
assert hasattr(result, 'head_attn_nomask_seqlens')
assert hasattr(result, 'tail_attn_nomask_seqlens')
if hasattr(result, 'pcp_prefill_mask'
) and result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048, 2048)
else:
if hasattr(result, 'pcp_prefill_mask'):
if result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048,
2048)
def test_generate_pcp_metadata_edge_cases():
mock_runner = MagicMock()
mock_runner.pcp_size = 2
mock_runner.dcp_size = 1
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_reqs = 0
mock_runner.query_lens = torch.tensor([10, 20, 30])
assert (mock_runner.input_batch.num_reqs
or mock_runner.query_lens.size(0)) == 3
mock_runner.input_batch.num_reqs = 100
mock_runner.query_lens = torch.ones(100) * 1000
for rank in [0, 1]:
mock_runner.pcp_rank = rank
q_head_chunk_id = rank
q_tail_chunk_id = 2 * 2 - 1 - rank
assert q_head_chunk_id == rank
assert q_tail_chunk_id == 3 - rank
def test_pcp_allgather_restore_idx_slicing():
mock_runner = MagicMock()
mock_runner.pcp_size = 2
mock_runner.pcp_allgather_restore_idx = torch.arange(1000)
total_num_scheduled_tokens = 200
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 2
expected_slice = mock_runner.pcp_allgather_restore_idx[:
num_actual_tokens_pcp_padded]
assert len(expected_slice) == 400
assert expected_slice[0] == 0
assert expected_slice[-1] == 399
@pytest.mark.parametrize(
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \
"pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens",
[
# Case 1: prefill only
([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]),
# Case 2: mix prefill and decode (with spec decode)
([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]),
# Case 3: request which need to be padded
([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]),
# Case 4: single request
([10], 1, [0], [10], 4, 0, 1, [4]),
])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
num_prompt_tokens, pcp_size, pcp_rank,
decode_threshold, expected_pcp_tokens):
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.pcp_size = pcp_size
mock_runner.pcp_rank = pcp_rank
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_reqs = num_reqs
mock_runner.input_batch.num_computed_tokens_cpu = np.array(
num_computed_tokens, dtype=np.int32)
mock_runner.input_batch.num_prompt_tokens = np.array(num_prompt_tokens,
dtype=np.int32)
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
mock_runner.num_pcp_pads = [0] * num_reqs
mock_runner.arange_np = np.arange(10000)
mock_runner.decode_threshold = decode_threshold
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
mock_runner, NPUModelRunner)
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
mock_runner, NPUModelRunner)
pcp_tokens_result, positions_result, unpad_mask_result = mock_runner._update_tokens_for_pcp(
tokens)
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
total_pcp_tokens: int = np.sum(pcp_tokens_result)
assert positions_result.shape == (total_pcp_tokens,), \
f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
padded_tokens = [
(t + 2 * pcp_size - 1) // (2 * pcp_size) *
(2 * pcp_size) if num_computed_tokens[i] == 0 else t * pcp_size
for i, t in enumerate(tokens)
]
total_padded_tokens: int = np.sum(padded_tokens)
assert unpad_mask_result.shape[0] == total_padded_tokens, \
f"unpad_mask size mismatch: expected {total_padded_tokens}, got {unpad_mask_result.shape[0]}"
def test_update_tokens_for_pcp_with_padding():
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.pcp_size = 4
mock_runner.pcp_rank = 0
mock_runner.arange_np = np.arange(10000)
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_reqs = 3
mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0, 0],
dtype=np.int32)
mock_runner.input_batch.num_prompt_tokens = np.array([5, 9, 13],
dtype=np.int32)
mock_runner.num_pcp_pads = [0, 0, 0]
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
mock_runner.decode_threshold = 1
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
mock_runner, NPUModelRunner)
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
mock_runner, NPUModelRunner)
tokens = [5, 9, 13]
pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
tokens)
expected_pcp_tokens = [2, 4, 4]
assert np.array_equal(pcp_tokens, expected_pcp_tokens), \
f"Expected {expected_pcp_tokens}, got {pcp_tokens}"
expected_pads = [3, 7, 3]
assert np.array_equal(mock_runner.num_pcp_pads, expected_pads), \
f"Expected padding {expected_pads}, got {mock_runner.num_pcp_pads}"
def test_update_tokens_for_pcp_unpad_mask():
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.pcp_size = 4
mock_runner.pcp_rank = 0
mock_runner.arange_np = np.arange(10000)
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_reqs = 2
mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0],
dtype=np.int32)
mock_runner.input_batch.num_prompt_tokens = np.array([5, 7],
dtype=np.int32)
mock_runner.num_pcp_pads = [0, 0]
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
mock_runner.decode_threshold = 1
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
mock_runner, NPUModelRunner)
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
mock_runner, NPUModelRunner)
tokens = [5, 7]
pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
tokens)
assert unpad_mask.dtype == torch.bool, \
f"unpad_mask should be bool, got {unpad_mask.dtype}"
padded_tokens = [8, 8]
expected_length = sum(padded_tokens)
assert unpad_mask.shape[0] == expected_length, \
f"unpad_mask length mismatch: expected {expected_length}, got {unpad_mask.shape[0]}"
expected_mask = [True] * 5 + [False] * 3 + [True] * 7 + [False] * 1
actual_mask = unpad_mask.numpy().tolist()
assert actual_mask == expected_mask, \
f"unpad_mask incorrect. Expected {expected_mask}, got {actual_mask}"
# yapf: disable
@pytest.mark.parametrize(
"seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
[
# without pcp and dcp
(torch.tensor([1, 2, 128, 129]), 1, 1, 1,
torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
# pcp
(torch.tensor([1, 2, 128, 129]), 2, 1, 1,
torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
# dcp
(torch.tensor([1, 2, 128, 129]), 1, 2, 1,
torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
# pcp + dcp
(torch.tensor([1, 2, 128, 129]), 2, 2, 1,
torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
[[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
# specify interleave_size
(torch.tensor([1, 2, 128, 129]), 2, 1, 2,
torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
(torch.tensor([1, 2, 128, 129]), 2, 1, 128,
torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
(torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
[[128, 0], [0, 0]], [[128, 1], [0, 0]],
[[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
]
)
# yapf: enable
def test_get_cp_local_seq_lens(
seq_lens,
pcp_world_size,
dcp_world_size,
cp_kv_cache_interleave_size,
target,
):
mock_runner = MagicMock(spec=NPUModelRunner)
ret = NPUModelRunner._get_cp_local_seq_lens(mock_runner, seq_lens,
pcp_world_size, dcp_world_size,
cp_kv_cache_interleave_size)
assert torch.equal(ret, target)
@pytest.fixture
def pcp_mtp_mock_runner():
# set up pcp & mtp related buffers
max_num_reqs = 4
max_model_len = 4096
max_num_tokens = 4096
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.device = 'cpu'
mock_runner.pin_memory = False
# Init model_runner pcp_mtp related buffers
mock_runner.query_start_loc_pcp_full = NPUModelRunner._make_buffer(
mock_runner, max_num_reqs + 1, dtype=torch.int32)
positions_buff = torch.zeros(max_num_tokens,
dtype=torch.int64,
device="cpu")
mock_runner.positions_pcp_full = positions_buff
mock_runner.positions_pcp_full_np = positions_buff.numpy()
mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer(
mock_runner, max_num_tokens, dtype=torch.int32)
mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer(
mock_runner, max_num_reqs, dtype=torch.int32)
mock_runner.decode_threshold = 1
mock_runner.arange_np = np.arange(max_model_len)
mock_runner.input_batch = MagicMock()
mock_runner.input_batch.num_computed_tokens_cpu = \
np.zeros(max_num_reqs, dtype=np.int32)
token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
)
mock_runner.input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
mock_runner.input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()
return mock_runner
# yapf: disable
@pytest.mark.parametrize(
"req_ids, num_computed_tokens," \
"token_ids_tensor_list," \
"num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
"target_input_ids_pcp_full, target_query_start_loc_pcp_full",
[
# prefill
(
['0'], np.array([0]),
[torch.tensor([0, 671, 6102, 294, 8760, 344])],
1, 6, {'0': 6},
torch.tensor([0, 671, 6102, 294, 8760, 344]),
torch.tensor([0, 6])
),
# decode
(
['0'], np.array([6]),
[torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
1, 2, {'0': 2},
torch.tensor([88907, 0]),
torch.tensor([0, 2])
),
# decode + prefill
(
['0', '1'], np.array([6, 0]),
[
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
],
2, 12, {'0': 2, '1': 10},
torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
torch.tensor([0, 2, 12])
),
# decodes + prefills
(
['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
[
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
],
4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
0, 671, 5335, 1469, 7539, 305, 6397]),
torch.tensor([0, 2, 4, 12, 19])
),
])
# yapf: enable
def test_generate_pcp_mtp_input(
pcp_mtp_mock_runner,
req_ids,
num_computed_tokens,
token_ids_tensor_list,
num_reqs,
total_num_scheduled_tokens,
num_scheduled_tokens,
target_input_ids_pcp_full,
target_query_start_loc_pcp_full,
):
mock_runner = pcp_mtp_mock_runner
token_ids_cpu_tensor = mock_runner.input_batch.token_ids_cpu_tensor
# Set input_batch
mock_runner.input_batch.req_ids = req_ids
mock_runner.input_batch.num_computed_tokens_cpu[:num_computed_tokens.
size] = num_computed_tokens
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
NPUModelRunner._generate_pcp_mtp_input(mock_runner, num_reqs,
total_num_scheduled_tokens,
num_scheduled_tokens)
assert torch.equal(
mock_runner.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
target_input_ids_pcp_full)
assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1],
target_query_start_loc_pcp_full)


@@ -0,0 +1,322 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from vllm_ascend.worker.pcp_utils import PCPManager
@pytest.mark.parametrize(
"pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
[
(1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
(1, 2, 3, [20, 30, 40], 1, False, 50, True),
(2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
(2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
(2, 1, 3, [5, 10, 15], 3, False, 50, True),
(2, 1, 3, [40, 50, 60], 0, False, 150, True),
])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
num_decodes, use_mla, total_tokens,
expect_not_none):
vllm_config = MagicMock()
vllm_config.model_config = MagicMock()
vllm_config.model_config.use_mla = use_mla
vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
vllm_config.speculative_config.num_speculative_tokens = 0
pcp_manager = PCPManager(pcp_world_size=pcp_size,
pcp_rank=0,
dcp_world_size=dcp_size,
dcp_rank=0,
max_buffer_num_tokens=10000,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
num_computed_tokens = []
num_prompt_tokens = []
num_tokens = []
for i in range(num_reqs):
if i < num_decodes:
num_computed_tokens.append(query_lens[i])
num_prompt_tokens.append(query_lens[i] // 2)
num_tokens.append(query_lens[i])
else:
num_computed_tokens.append(0)
num_prompt_tokens.append(query_lens[i])
num_tokens.append(query_lens[i])
input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
input_batch.num_tokens = torch.tensor(num_tokens)
query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
input_batch)
if not expect_not_none:
assert result is None, f"Expected to return None, but got {type(result)}"
else:
assert result is not None, "Expected to return a metadata object, but got None."
assert hasattr(result, 'num_actual_tokens_pcp_padded')
assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')
if pcp_size > 1:
assert hasattr(result, 'pcp_allgather_restore_idx')
has_prefill_requests = (num_reqs - num_decodes) > 0
if has_prefill_requests:
assert hasattr(result, 'q_head_idx_tensor')
assert hasattr(result, 'q_tail_idx_tensor')
assert hasattr(result, 'q_full_idx')
assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
assert hasattr(result, 'attn_mask_seqlens')
assert hasattr(result, 'head_attn_nomask_seqlens')
assert hasattr(result, 'tail_attn_nomask_seqlens')
if hasattr(result, 'pcp_prefill_mask'
) and result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048, 2048)
else:
if hasattr(result, 'pcp_prefill_mask'):
if result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048,
2048)
@pytest.mark.parametrize(
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
[
# Case 1: prefill only
([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
# Case 2: mix prefill and decode
([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]),
# Case 3: request which need to be padded
([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
# Case 4: single request
([10], 1, [0], [10], 4, 0, [4]),
])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
num_prompt_tokens, pcp_size, pcp_rank,
expected_pcp_tokens):
vllm_config = MagicMock()
vllm_config.model_config = MagicMock()
vllm_config.speculative_config.num_speculative_tokens = 0
pcp_manager = PCPManager(pcp_world_size=pcp_size,
pcp_rank=0,
dcp_world_size=1,
dcp_rank=0,
max_buffer_num_tokens=10000,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
pin_memory=False)
input_batch = MagicMock()
input_batch.num_reqs = num_reqs
input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens,
dtype=np.int32)
input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
arange_np = np.arange(10000)
pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
np.array(tokens), arange_np, num_reqs, 1)
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
total_pcp_tokens: int = np.sum(pcp_tokens_result)
assert positions_result.shape == (total_pcp_tokens,), \
f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
# yapf: disable
@pytest.mark.parametrize(
"seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
[
# without pcp and dcp
(torch.tensor([1, 2, 128, 129]), 1, 1, 1,
torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
# pcp
(torch.tensor([1, 2, 128, 129]), 2, 1, 1,
torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
# dcp
(torch.tensor([1, 2, 128, 129]), 1, 2, 1,
torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
# pcp + dcp
(torch.tensor([1, 2, 128, 129]), 2, 2, 1,
torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
[[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
# specify interleave_size
(torch.tensor([1, 2, 128, 129]), 2, 1, 2,
torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
(torch.tensor([1, 2, 128, 129]), 2, 1, 128,
torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
(torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
[[128, 0], [0, 0]], [[128, 1], [0, 0]],
[[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
]
)
# yapf: enable
def test_get_cp_local_seq_lens(
seq_lens,
pcp_world_size,
dcp_world_size,
cp_kv_cache_interleave_size,
target,
):
vllm_config = MagicMock()
vllm_config.model_config = MagicMock()
vllm_config.speculative_config.num_speculative_tokens = 0
pcp_manager = PCPManager(pcp_world_size=pcp_world_size,
pcp_rank=0,
dcp_world_size=dcp_world_size,
dcp_rank=0,
max_buffer_num_tokens=10000,
max_num_reqs=1000,
device="cpu",
vllm_config=vllm_config,
pin_memory=False)
ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
dcp_world_size,
cp_kv_cache_interleave_size)
assert torch.equal(ret, target)
# yapf: disable
@pytest.mark.parametrize(
"req_ids, num_computed_tokens," \
"token_ids_tensor_list," \
"num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
"target_input_ids_pcp_full, target_query_start_loc_pcp_full",
[
# prefill
(
['0'], np.array([0]),
[torch.tensor([0, 671, 6102, 294, 8760, 344])],
1, 6, {'0': 6},
torch.tensor([0, 671, 6102, 294, 8760, 344]),
torch.tensor([0, 6])
),
# decode
(
['0'], np.array([6]),
[torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
1, 2, {'0': 2},
torch.tensor([88907, 0]),
torch.tensor([0, 2])
),
# decode + prefill
(
['0', '1'], np.array([6, 0]),
[
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
],
2, 12, {'0': 2, '1': 10},
torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
torch.tensor([0, 2, 12])
),
# decodes + prefills
(
['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
[
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
],
4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
0, 671, 5335, 1469, 7539, 305, 6397]),
torch.tensor([0, 2, 4, 12, 19])
),
])
# yapf: enable
def test_generate_pcp_mtp_input(
req_ids,
num_computed_tokens,
token_ids_tensor_list,
num_reqs,
total_num_scheduled_tokens,
num_scheduled_tokens,
target_input_ids_pcp_full,
target_query_start_loc_pcp_full,
):
max_num_reqs = 4
max_model_len = 4096
max_num_tokens = 4096
vllm_config = MagicMock()
vllm_config.model_config = MagicMock()
vllm_config.speculative_config.num_speculative_tokens = 1
vllm_config.scheduler_config.max_num_seqs = max_num_reqs
vllm_config.scheduler_config.max_num_batched_tokens = max_model_len
pcp_manager = PCPManager(pcp_world_size=2,
pcp_rank=0,
dcp_world_size=1,
dcp_rank=0,
max_buffer_num_tokens=max_num_tokens,
max_num_reqs=max_num_reqs,
device="cpu",
vllm_config=vllm_config,
pin_memory=False)
arange_np = np.arange(max_model_len)
input_batch = MagicMock()
input_batch.num_computed_tokens_cpu = \
np.zeros(max_num_reqs, dtype=np.int32)
token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
)
input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()
token_ids_cpu_tensor = input_batch.token_ids_cpu_tensor
# Set input_batch
input_batch.req_ids = req_ids
input_batch.num_computed_tokens_cpu[:num_computed_tokens.
size] = num_computed_tokens
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
num_scheduled_tokens, False,
input_batch, arange_np)
assert torch.equal(
pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
target_input_ids_pcp_full)
assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
target_query_start_loc_pcp_full)


@@ -279,9 +279,9 @@ class EagleProposer(VllmEagleProposer):
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.input_ids_pcp_full
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]


@@ -179,7 +179,7 @@ class MtpProposer(EagleProposer):
if self.pcp_size * self.dcp_size > 1:
# update long_seq related params and flatten block_table
common_attn_metadata.prefill_context_parallel_metadata = \
self.runner.long_seq_metadata
self.runner.pcp_manager.long_seq_metadata
common_attn_metadata.block_table_tensor = \
self.runner.input_batch.block_table[0].get_device_tensor()[
:num_reqs * self.decode_threshold]
@@ -286,9 +286,9 @@ class MtpProposer(EagleProposer):
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
@@ -303,7 +303,7 @@ class MtpProposer(EagleProposer):
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
@@ -751,8 +751,8 @@ class MtpProposer(EagleProposer):
hidden_states = hidden_states[:num_tokens]
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
hidden_states, 0, self.runner.
pcp_allgather_restore_idx[:hidden_states.shape[0]])
hidden_states, 0, self.runner.pcp_manager.
pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]])
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -797,13 +797,13 @@ class MtpProposer(EagleProposer):
# (_generate_pcp_mtp_input), and use updated slot_indices
# to get corresponding slot_mapping in each step.
num_reject_tokens = torch.tensor(
self.runner.cu_num_tokens_pcp_full,
self.runner.pcp_manager.cu_num_tokens_pcp_full,
dtype=torch.int32).to(
self.device) - ori_last_token_indices - 1
num_accept_tokens = \
query_lens_d.to(self.device) - num_reject_tokens
ori_seq_len = attn_metadata_i.seq_lens
mtp_slot_mapping = self.runner.mtp_slot_pad
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
# slot_mapping index base offset:
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
@@ -889,7 +889,7 @@ class MtpProposer(EagleProposer):
self.hidden_states[:hidden_states.shape[0]] = hidden_states
if self.pcp_size * self.dcp_size > 1:
# update local seq_len and batch_seq_mask
num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
ori_seq_len + step + 1,
self.pcp_size,
self.dcp_size,


@@ -24,7 +24,7 @@ from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union
import numpy as np
import torch
@@ -78,8 +78,7 @@ from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
# yapf conflicts with isort for this block
# yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -109,6 +108,7 @@ from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
lmhead_tp_enable, maybe_trans_nz,
set_weight_prefetch_method)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.worker.pcp_utils import PCPManager
from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
@@ -202,6 +202,26 @@ class NPUModelRunner(GPUModelRunner):
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = (self.max_num_tokens +
self.max_num_reqs * 2 * self.pcp_size)
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens,
dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens,
dtype=torch.int64)
self.sampler = AscendSampler()
self.attn_mask = None
self.attn_state = None
@@ -262,32 +282,6 @@ class NPUModelRunner(GPUModelRunner):
set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
self.uniform_decode_query_len)
set_mc2_mask(vllm_config, self.device)
self.pcp_allgather_restore_idx = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
[] for _ in range(self.pcp_size)
]
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
self.pcp_padded_slot_mapping = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.num_actual_tokens_pcp_padded = 0
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.query_start_loc_pcp_full = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32)
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=True)
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
@@ -359,6 +353,7 @@ class NPUModelRunner(GPUModelRunner):
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
self.reorder_batch_threshold: int | None = None
self.long_seq_metadata = None
def _init_device_properties(self) -> None:
self.num_sms = None
@@ -508,49 +503,6 @@ class NPUModelRunner(GPUModelRunner):
return self.attn_mask_builder.get_mla_mask(self.dtype)
return self.attn_mask_builder.get_splitfuse_attn_mask()
def generate_kv_idx(self, scheduler_output):
if not self.pcp_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
is_prefill = self.input_batch.num_computed_tokens_cpu[
i] < self.input_batch.num_prompt_tokens[i]
if is_prefill:
num_cp_padded_scheduled_tokens = cdiv(
num_scheduled_tokens,
2 * self.pcp_size) * (2 * self.pcp_size)
full_indices = list(
range(self.max_num_tokens * self.pcp_size * self.dcp_size +
self.pcp_size * self.dcp_size * self.max_num_reqs))
chunk_size = num_cp_padded_scheduled_tokens // (2 *
self.pcp_size)
num_added_recover_tokens = len(
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size
for rank in range(self.pcp_size):
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[rank * chunk_size +
num_added_recover_tokens:(rank + 1) *
chunk_size + num_added_recover_tokens])
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[num_cp_padded_scheduled_tokens -
(rank + 1) * chunk_size +
num_added_recover_tokens:
num_cp_padded_scheduled_tokens -
rank * chunk_size +
num_added_recover_tokens])
cp_kv_recover_idx_for_chunk = torch.from_numpy(
np.concatenate(
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
non_blocking=True)
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
torch.float32).argsort().to(torch.int32)
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -574,43 +526,70 @@ class NPUModelRunner(GPUModelRunner):
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
positions_np = np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
total_num_pcp_pads = 0
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
tokens_before_update = tokens.copy()
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
total_num_pcp_pads = torch.sum(self.num_pcp_pads[:num_reqs]).item()
else:
position_pcp, pcp_unpad_mask = None, None
self.num_pcp_pads[:num_reqs] = 0
max_num_scheduled_tokens = max(tokens)
if not scheduler_output.scheduled_spec_decode_tokens:
num_valid_tokens = np.array(tokens, dtype=np.int32)
else:
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip((tokens_before_update if self.
pcp_size > 1 else tokens), req_ids)
for num_tokens, i in zip(tokens, req_ids)
],
dtype=np.int32)
# Get the attention state.
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.attn_mask = self._make_attention_mask(attn_state)
# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
# For PCP, prefill MTP should use the original scheduler_output.
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.pcp_manager.generate_pcp_mtp_input(
num_reqs, total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens, with_prefill,
self.input_batch, self.arange_np, req_indices, positions_np,
cu_num_tokens)
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.pcp_manager.generate_kv_idx(scheduler_output,
self.input_batch)
num_scheduled_tokens[:
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
num_scheduled_tokens[:num_reqs],
self.arange_np,
self.input_batch.num_reqs,
self.reorder_batch_threshold,
)
# Re-update after PCP split sequences.
total_num_scheduled_tokens = sum(num_scheduled_tokens)
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
cu_num_tokens, _ = self._get_cumsum_and_arange(
num_scheduled_tokens)
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
position_pcp[:total_num_scheduled_tokens],
out=positions_np,
)
max_num_scheduled_tokens = max(tokens)
if (self.use_aclgraph and total_num_scheduled_tokens
<= self.cudagraph_batch_sizes[-1]):
# Add padding to the batch size.
@@ -627,17 +606,6 @@ class NPUModelRunner(GPUModelRunner):
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Get the attention state.
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get info across DP ranks.
@@ -646,7 +614,7 @@ class NPUModelRunner(GPUModelRunner):
(maybe_padded_num_tokens, num_tokens_across_dp,
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
with_prefill)
self.with_prefill = with_prefill
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
# We should consider removing maybe_padded_num_tokens later
num_input_tokens = maybe_padded_num_tokens
@@ -655,24 +623,8 @@ class NPUModelRunner(GPUModelRunner):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
if self.pcp_size > 1:
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
position_pcp[:total_num_scheduled_tokens],
out=positions_np)
else:
self.positions.np[:total_num_scheduled_tokens] = positions_np
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self._calc_mrope_positions(scheduler_output)
@@ -766,21 +718,11 @@ class NPUModelRunner(GPUModelRunner):
self.seq_lens.gpu[num_reqs:].fill_(0)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
cu_num_tokens)
self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions.copy_to_gpu()
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(attn_state)
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
attn_metadata: dict[str, Any] = {}
# Record the index of requests that should not be sampled,
@@ -914,9 +856,8 @@ class NPUModelRunner(GPUModelRunner):
# TODO: Support prompt logprobs.
spec_decode_metadata = None
if self.pcp_size * self.dcp_size > 1:
logits_indices = torch.from_numpy(
cu_num_tokens
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
logits_indices = self.pcp_manager.get_logits_indices(
cu_num_tokens, num_reqs)
logits_indices = logits_indices.pin_memory().to(
self.device, non_blocking=True)
else:
@@ -938,8 +879,10 @@ class NPUModelRunner(GPUModelRunner):
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens,
self.num_pcp_pads[:num_reqs].numpy())
num_draft_tokens,
cu_num_tokens,
num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
if self.pcp_size > 1 else None)
logits_indices = spec_decode_metadata.logits_indices
# For DECODE only cuda graph of some attention backends (e.g., GDN).
@@ -961,23 +904,10 @@ class NPUModelRunner(GPUModelRunner):
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self._generate_pcp_mtp_input(
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens, with_prefill,
req_indices, positions_np, cu_num_tokens)
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens)
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# NOTE: This is strange, why did we use total_num_scheduled_tokens before?
slot_mapping_size = (total_num_scheduled_tokens
if self.pcp_size == 1 else
total_num_scheduled_tokens * self.pcp_size -
total_num_pcp_pads)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
@@ -993,30 +923,30 @@ class NPUModelRunner(GPUModelRunner):
device=self.device,
)
else:
maybe_pcp_full_tokens = (
num_input_tokens if self.pcp_size == 1 else
total_num_scheduled_tokens * self.pcp_size -
sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]))
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
slot_mapping_for_pcp = blk_table.slot_mapping.gpu[:
long_seq_metadata
.
num_actual_tokens_pcp_padded]
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
assert pcp_unpad_mask is not None
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
pcp_unpad_mask
.
shape[
0]]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[
pcp_unpad_mask] = slot_mapping_for_pcp[:
slot_mapping_size]
slot_mapping_for_pcp[:long_seq_metadata.
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
slot_mapping_for_pcp
slot_mapping = blk_table.slot_mapping.gpu[:
maybe_pcp_full_tokens]
if self.pcp_size * self.dcp_size == 1:
slot_mapping[
total_num_scheduled_tokens:num_input_tokens].fill_(-1)
slot_mapping = blk_table.slot_mapping.gpu
if self.pcp_size * self.dcp_size > 1:
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
total_num_scheduled_tokens, self.query_lens,
self.attn_mask, self.input_batch)
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
slot_mapping = slot_mapping[:maybe_pcp_full_tokens]
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
total_num_scheduled_tokens,
slot_mapping,
)
blk_table.slot_mapping.gpu[:self.pcp_manager.
num_actual_tokens_pcp_padded] = slot_mapping
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
# has been split to multiple parts, and there are 3 parts that is related to this
@@ -1055,7 +985,7 @@ class NPUModelRunner(GPUModelRunner):
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.gpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=slot_mapping_size,
num_actual_tokens=total_num_scheduled_tokens,
num_input_tokens=num_input_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
@@ -1069,8 +999,9 @@ class NPUModelRunner(GPUModelRunner):
attn_state=self.attn_state,
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
max_seq_len=0)
prefill_context_parallel_metadata=self.long_seq_metadata,
max_seq_len=0,
)
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
@@ -1080,8 +1011,10 @@ class NPUModelRunner(GPUModelRunner):
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[:
num_reqs]
ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[:
num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
@@ -1097,13 +1030,17 @@ class NPUModelRunner(GPUModelRunner):
ori_query_lens[:num_decode_reqs], dim=0))
common_attn_metadata.block_table_tensor = \
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
assert self.long_seq_metadata is not None
self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
if 'pad_size' in locals() and pad_size > 0:
ori_query_lens_cpu[-pad_size:] = \
torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
long_seq_metadata.max_query_len_pcp_full = \
self.long_seq_metadata.max_query_len_pcp_full = \
ori_query_lens_cpu.max().item()
if self.speculative_config and \
self.spec_decode_common_attn_metadata is None:
self.spec_decode_common_attn_metadata = common_attn_metadata
@@ -1193,19 +1130,12 @@ class NPUModelRunner(GPUModelRunner):
pad_size = get_forward_context().pad_size
if pad_size > 0:
hidden_states = hidden_states[:-pad_size, :]
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states[:self.num_actual_tokens_pcp_padded //
self.pcp_size], 0)
hidden_states = torch.index_select(
hidden_states, 0,
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
return hidden_states
return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
hidden_states)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
@@ -1231,7 +1161,7 @@ class NPUModelRunner(GPUModelRunner):
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
num_pcp_pads: np.ndarray,
num_pcp_pads: np.ndarray | None,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
@@ -1846,7 +1776,9 @@ class NPUModelRunner(GPUModelRunner):
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
num_tokens, self.query_lens, self.attn_mask,
self.input_batch)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group().world_size
dcp_world_size = get_dcp_group().world_size
@@ -2890,365 +2822,6 @@ class NPUModelRunner(GPUModelRunner):
parent_module_name):
super().capture_model()
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
num_decode_tokens = sum(tokens[:num_decode_reqs])
num_padded_scheduled_tokens = np.ceil(
tokens /
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = (
tokens[:num_decode_reqs] * self.pcp_size)
self.num_pcp_pads[:num_reqs] = torch.tensor(
num_padded_scheduled_tokens - tokens)
cu_padded_tokens, pcp_padded_arange = \
self._get_cumsum_and_arange(num_padded_scheduled_tokens)
unpad_mask = torch.from_numpy(
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
pcp_tokens)
def get_current_rank_positions(cu_tokens, rank):
positions_start_loc = np.zeros_like(cu_tokens)
positions_start_loc[1:] = cu_tokens[:-1]
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = positions_start_loc + \
(2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
np.repeat(head_start_loc, pcp_chunk_sizes)
# Decode reqs do not have tail chunks.
positions[~pcp_head_chunk_mask] = \
pcp_chunk_arange[num_decode_tokens:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
return positions
positions = get_current_rank_positions(
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
# Decode tokens are duplicate and their positions always be 0.
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
tokens[:num_decode_reqs])[1]
all_positions = [
get_current_rank_positions(cu_padded_tokens, rank_i)
for rank_i in range(self.pcp_size)
]
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
all_positions_tensor.float().argsort().long(), non_blocking=True)
return pcp_tokens, positions, unpad_mask
def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
dcp_world_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each (p/d)cp rank.
"""
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size,
dtype=torch.int32).unsqueeze(0).repeat(
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
# In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
if self.pcp_size > 1 and self.speculative_config else self.query_lens
num_decodes = (query_lens <= self.decode_threshold).sum().item()
num_prefills = num_reqs - num_decodes
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
decode_context_lens = self.input_batch.num_tokens[:num_decodes]
prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
self.dcp_size
],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape.
# Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_cp_local_seq_lens(
torch.tensor(context_lens) - decode_idx,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if num_decodes:
num_decodes_flatten = \
self.query_lens[:num_decodes].sum().item()
if self.query_lens[:num_decodes].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(num_decodes):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - self.query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if num_prefills:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(num_decodes + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy())
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
chunk_seqlens = []
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
q_req_offset = 0
kv_req_offset = 0
q_head_chunk_id = self.pcp_rank
q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
for i, seq_len in enumerate(self.query_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len *
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len *
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_size
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
tensor_npu = torch.zeros(len(lst),
dtype=dtype,
device=device)
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
non_blocking=True)
return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to(
torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor':
kv_with_q_head_nomask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
'kv_with_q_tail_nomask_idx_tensor':
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
}
for key, value in self.kv_idx_names.items():
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens],
dtype=torch.int32)
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
pcp_prefill_mask = self.attn_mask
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
'pcp_prefill_mask': pcp_prefill_mask
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
num_actual_tokens_pcp_padded]
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor']
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_mask_idx_tensor']
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor']
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_mask_idx_tensor']
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
'attn_mask_seqlens']
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
def _generate_pcp_mtp_input(
self,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: dict[str, int],
with_prefill: bool = True,
req_indices=None,
positions_np=None,
cu_num_tokens=None,
):
"""
        When pcp > 1, model inputs (input_ids, positions, etc.) are split across the pcp group,
        but MTP needs to shift the original input_ids before pcp splitting,
        so we record the original input_ids here.
"""
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:num_reqs +
1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
num_scheduled_tokens_pcp_full)
arange_pcp_full = self.arange_np[:
total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_pcp_full_np = self.positions_pcp_full_np[:
total_num_scheduled_tokens_pcp_full]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full,
out=positions_pcp_full_np)
token_indices_pcp_full = (
positions_pcp_full_np +
req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:num_reqs]
for req_idx in range(num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
ori_req_indice,
np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
positions_split[req_idx] = np.append(
ori_position,
np.arange(ori_position[-1] + 1,
ori_position[-1] + self.decode_threshold - 1))
req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split)
self.input_batch.block_table.compute_slot_mapping(
req_indices_mtp, positions_mtp)
mtp_slot_ori = self.input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_size] = True
mtp_slot_pad = \
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
def _prepare_multimodal_fields(self):
"""
        Ensures specific multimodal tensors are on CPU.

View File

@@ -0,0 +1,686 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#
from typing import List
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class PCPManager:
"""
Manager for Prefill Context Parallelism (PCP) metadata and buffers.
This manager encapsulates all PCP-related buffers and logic so that the
ModelRunner can access them via `self.pcp_manager`.
"""
def __init__(
self,
pcp_world_size: int,
pcp_rank: int,
dcp_world_size: int,
dcp_rank: int,
max_buffer_num_tokens: int,
max_num_reqs: int,
device: torch.device,
vllm_config: VllmConfig,
pin_memory: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank
self.dcp_world_size = dcp_world_size
self.dcp_world_rank = dcp_rank
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
self.vllm_config = vllm_config
self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
self.device = device
self.pcp_allgather_restore_idx = CpuGpuBuffer(
max_buffer_num_tokens,
dtype=torch.int64,
device=device,
pin_memory=pin_memory,
)
self.pcp_padded_slot_mapping = torch.full(
(max_buffer_num_tokens, ),
fill_value=-1,
dtype=torch.int32,
device=device,
)
self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ),
device="cpu",
dtype=torch.int64)
self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy()
self.pcp_unpad_mask_cpu_tensor = torch.zeros(
(max_buffer_num_tokens, ),
device="cpu",
dtype=torch.bool,
)
self.num_actual_tokens_pcp_padded = 0
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
[] for _ in range(self.pcp_world_size)
]
self.full_indices = list(
range(self.max_num_tokens * self.pcp_world_size *
self.dcp_world_size + self.pcp_world_size *
self.dcp_world_size * self.max_num_reqs))
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
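    # Illustrative usage (a sketch under assumed runner attribute names, not
    # part of this change): a model runner would typically build the manager
    # once during init and delegate PCP bookkeeping to it, e.g.
    #
    #   self.pcp_manager = PCPManager(
    #       pcp_world_size=self.pcp_size, pcp_rank=self.pcp_rank,
    #       dcp_world_size=self.dcp_size, dcp_rank=self.dcp_rank,
    #       max_buffer_num_tokens=self.max_num_tokens * self.pcp_size,
    #       max_num_reqs=self.max_num_reqs, device=self.device,
    #       vllm_config=self.vllm_config, pin_memory=self.pin_memory)
    #
    # so that token splitting, restore indices and slot-mapping padding no
    # longer live in the model runner itself.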
def _get_cumsum_and_arange(
self,
num_scheduled_tokens: np.ndarray,
arange_np: np.ndarray,
cumsum_dtype: np.dtype | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_scheduled_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = arange_np[:total_num_tokens] - cumsums_offsets
return cu_num_tokens, arange
def update_tokens_for_pcp(
self,
num_scheduled_tokens: np.ndarray,
arange_np: np.ndarray,
num_reqs: int,
reorder_batch_threshold: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Update token counts and positions for Prefill Context Parallelism (PCP).
When using Prefill Context Parallelism, each request's prefill sequence is
split across multiple PCP ranks. The splitting strategy used here is the
"DualChunkSwap" style: each request's (padded) sequence is split into
2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved
head/tail pattern to balance load.
This function:
- Computes how many tokens each request should be processed by the current
PCP rank (pcp_tokens).
- Computes the flattened positions of those tokens within the local
padded buffer (pcp_positions).
        - Updates manager state arrays used to restore the original order and mask out
padded tokens after allgather:
- self.num_pcp_pads_cpu: number of pads added per request
- self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the
padded allgather buffer
- self.pcp_allgather_restore_idx: index array used to restore original
ordering after per-rank allgather and interleaving.
Args:
num_scheduled_tokens: 1D numpy array of length num_reqs containing
the number of new tokens scheduled per request.
arange_np: 1D numpy array of length max_buffer_num_tokens used for
efficient batched arange operations.
num_reqs: Total number of requests in the batch.
reorder_batch_threshold: Threshold for decode vs prefill requests.
Returns:
Tuple (pcp_tokens, pcp_positions):
- pcp_tokens: number of tokens per request that this PCP rank will
actually process (after splitting / replication).
- pcp_positions: flattened positions for those tokens on this rank,
used to build the positions buffer for the model.
Example:
            >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After update_tokens_for_pcp:
>>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7])
>>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5])
>>> Meanwhile, the following results are same for each pcp rank
>>> self.num_pcp_pads_cpu
[1, 3, 0]
>>> self.pcp_unpad_mask_cpu
[True, False, True, True, True, True, True, False, False,
False, True, True, True, True, True, True, True, True]
            >>> self.pcp_allgather_restore_idx
[0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
"""
assert reorder_batch_threshold is not None, (
"PCP depends on reorder batch to split decode and prefill requests."
)
num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
# We first pad each request's token count up to that multiple.
num_padded_scheduled_tokens = np.ceil(
num_scheduled_tokens / (2 * self.pcp_world_size)).astype(
np.int32) * (2 * self.pcp_world_size)
# PCP does not split decode requests. For decode requests, we instead
# duplicate the scheduled tokens across the pcp_world_size ranks.
num_padded_scheduled_tokens[:num_decode_reqs] = (
num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size)
# Record how many pads were added per request (padded - original).
self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens -
num_scheduled_tokens)
# cu_padded_tokens: cumulative sum of padded token counts,
# pcp_padded_arange: per-request arange flattened for padded tokens.
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
num_padded_scheduled_tokens, arange_np)
# Build the mask that marks which positions in the padded allgather buffer
# correspond to real (unpadded) tokens.
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
pcp_padded_arange < np.repeat(num_scheduled_tokens,
num_padded_scheduled_tokens))
unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens *
self.pcp_world_size]
unpad_mask_decode = unpad_mask_decode.reshape(
[-1, self.pcp_world_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size
# Compute per-request "chunk sizes" for the head/tail splitting.
# For prefill requests, we further split the pcp_tokens into two chunks
# (head and tail). For decode requests, the chunk equals pcp_tokens.
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
# Build arange-style helpers for pcp tokens and chunk sizes:
# - pcp_arange gives indices repeated for each token in pcp_tokens
# - pcp_chunk_arange gives indices repeated for each position inside chunks
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np)
_, pcp_chunk_arange = self._get_cumsum_and_arange(
pcp_chunk_sizes, arange_np)
# Mask that marks whether a position belongs to the head chunk (True)
# or the tail chunk (False). For decode requests, tail chunk won't exist
# and is handled specially below.
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
pcp_tokens)
def get_current_rank_positions(positions_start_loc: int | np.ndarray,
rank: int):
"""
Compute flattened positions for the given rank with a given start
offset for each request (positions_start_loc).
- For head chunks: start at positions_start_loc + rank * chunk_size.
            - For tail chunks: start at positions_start_loc +
              (2 * pcp_world_size - rank - 1) * chunk_size.
- For decode requests: no tail chunks; their positions are filled from the
contiguous (unpadded) `tokens` arange instead (handled after).
"""
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = (
positions_start_loc +
(2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes)
# Fill head positions using chunk arange offset by head_start_loc.
positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(
head_start_loc, pcp_chunk_sizes)
# Fill tail positions. Note decode requests do not have tail chunks,
# so the tail filling is only for prefill positions.
positions[~pcp_head_chunk_mask] = (
pcp_chunk_arange[num_decode_tokens:] +
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:])
return positions
positions = get_current_rank_positions(0, self.pcp_world_rank)
        # Decode tokens are only duplicated after the allgather, so their
        # positions are the same as they would be without prefill context parallelism.
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:num_decode_reqs], arange_np)[1]
# Build the restore index used after allgather.
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
padded_pos_start_loc[0] = 0
all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i)
for rank_i in range(self.pcp_world_size)
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = (
all_positions.argsort())
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
return (
pcp_tokens[:num_reqs],
positions,
)
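    # A minimal sketch of the DualChunkSwap assignment assumed above: each
    # padded prefill is cut into 2 * pcp_world_size equal chunks, and rank r
    # processes chunk r (head) together with chunk 2 * pcp_world_size - 1 - r
    # (tail). With pcp_world_size = 2 and a prefill padded to 8 tokens
    # (chunk size 2):
    #   rank 0 -> tokens [0, 1] and [6, 7]
    #   rank 1 -> tokens [2, 3] and [4, 5]
    # which matches the per-request positions in the docstring example above.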
def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
self.num_pcp_pads_cpu_tensor[:num_reqs] - 1)
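    # Illustrative check using the update_tokens_for_pcp docstring example
    # (tokens [1, 5, 8], pads [1, 3, 0], pcp_world_size = 2): with per-rank
    # cu_num_tokens = [1, 5, 9], the logits indices in the restored padded
    # buffer are [1*2-1-1, 5*2-3-1, 9*2-0-1] = [0, 6, 17], i.e. the last real
    # token of each request.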
def get_discard_request_mask(
self,
num_computed_tokens_cpu: np.ndarray,
num_scheduled_tokens: np.ndarray,
num_reqs: int,
num_tokens_np: np.ndarray,
):
return (num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens * self.pcp_world_size -
self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np
def get_padded_slot_mapping(self, num_tokens: int,
slot_mapping: torch.Tensor):
        # After the pcp allgather and restore, the kv buffer contains padded tokens,
        # so we need to pad the slot mapping to keep it aligned.
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens *
self.
pcp_world_size]
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens *
self.pcp_world_size]
pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping
return pcp_padded_slot_mapping
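    # E.g. (illustrative): with pcp_world_size = 2, a single decode request
    # with one scheduled token has an unpad mask of [True, False], so a local
    # slot_mapping of [s0] is scattered into the padded buffer as [s0, -1];
    # padded positions keep the fill value -1.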
def get_restore_hidden_states(
self,
hidden_states: torch.Tensor,
):
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph.
from vllm.distributed.parallel_state import get_pcp_group
hidden_states = get_pcp_group().all_gather(
hidden_states[:self.num_actual_tokens_pcp_padded //
self.pcp_world_size],
0,
)
restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states.
shape[0]]
return torch.index_select(
hidden_states,
0,
restore_idx,
)
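    # Illustrative flow (a sketch): each pcp rank contributes its local share
    # of the hidden states; the all_gather along dim 0 produces a
    # rank0-then-rank1-... layout of num_actual_tokens_pcp_padded rows, and
    # indexing with pcp_allgather_restore_idx re-interleaves the rows back into
    # the padded per-request order, from which real tokens can be selected
    # (e.g. via get_logits_indices above).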
def generate_pcp_mtp_input(
self,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: dict[str, int],
with_prefill: bool = True,
input_batch=None,
arange_np=None,
req_indices=None,
positions_np=None,
cu_num_tokens=None,
):
"""
        When pcp > 1, model inputs (input_ids, positions, etc.) are split across the pcp group,
        but MTP needs to shift the original input_ids before pcp splitting,
        so we record the original input_ids here.
"""
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[:num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:num_reqs +
1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
num_scheduled_tokens_pcp_full)
arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_pcp_full_np = self.positions_pcp_full_np[:
total_num_scheduled_tokens_pcp_full]
np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full,
out=positions_pcp_full_np)
token_indices_pcp_full = (
positions_pcp_full_np +
req_indices_pcp_full * input_batch.token_ids_cpu.shape[1])
torch.index_select(input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.copy_to_gpu(
total_num_scheduled_tokens_pcp_full)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:num_reqs]
for req_idx in range(num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
ori_req_indice,
np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
positions_split[req_idx] = np.append(
ori_position,
np.arange(ori_position[-1] + 1,
ori_position[-1] + self.decode_threshold - 1))
req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split)
input_batch.block_table.compute_slot_mapping(
req_indices_mtp, positions_mtp)
mtp_slot_ori = input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_world_size] = True
mtp_slot_pad = \
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
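    # E.g. (illustrative): with pcp_world_size = 2 and num_tokens_mtp = 3, the
    # unpad mask is [True, False, True, False, True, False], so local MTP slots
    # [s0, s1, s2] are scattered into a padded buffer of length
    # num_tokens_mtp * pcp_world_size as [s0, -1, s1, -1, s2, -1].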
def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
dcp_world_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each (p/d)cp rank.
"""
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size,
dtype=torch.int32).unsqueeze(0).repeat(
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
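    # Another illustrative case: with pcp_world_size = 2, dcp_world_size = 1
    # and cp_kv_cache_interleave_size = 4, a request with seq_len = 10 is
    # presumably stored in interleaved blocks of 4 tokens (0-3 -> rank 0,
    # 4-7 -> rank 1, 8-9 -> rank 0), so
    #   _get_cp_local_seq_lens(torch.tensor([10]), 2, 1, 4)
    #   -> tensor([[[6], [4]]])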
def generate_kv_idx(self, scheduler_output, input_batch):
if not self.pcp_world_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[]
for _ in range(self.pcp_world_size)
]
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
is_prefill = input_batch.num_computed_tokens_cpu[
i] < input_batch.num_prompt_tokens[i]
if is_prefill:
num_cp_padded_scheduled_tokens = cdiv(
num_scheduled_tokens,
2 * self.pcp_world_size) * (2 * self.pcp_world_size)
chunk_size = num_cp_padded_scheduled_tokens // (
2 * self.pcp_world_size)
num_added_recover_tokens = len(
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
for rank in range(self.pcp_world_size):
self.cp_kv_recover_idx_for_chunk[rank].extend(
self.full_indices[rank * chunk_size +
num_added_recover_tokens:(rank + 1) *
chunk_size +
num_added_recover_tokens])
self.cp_kv_recover_idx_for_chunk[rank].extend(
self.full_indices[num_cp_padded_scheduled_tokens -
(rank + 1) * chunk_size +
num_added_recover_tokens:
num_cp_padded_scheduled_tokens -
rank * chunk_size +
num_added_recover_tokens])
cp_kv_recover_idx_for_chunk = torch.from_numpy(
np.concatenate(
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
non_blocking=True)
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
torch.float32).argsort().to(torch.int32)
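    # Illustrative example: with pcp_world_size = 2 and one prefill padded to 8
    # tokens (chunk size 2), rank 0 holds chunks [0, 3] (tokens 0-1 and 6-7)
    # while rank 1 holds chunks [1, 2] (tokens 2-3 and 4-5). The gathered kv is
    # then ordered [0, 1, 6, 7, 2, 3, 4, 5], and cp_kv_recover_idx_for_chunk is
    # its argsort [0, 1, 4, 5, 6, 7, 2, 3], which restores the original order.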
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
attn_mask, input_batch):
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
num_reqs = input_batch.num_reqs or query_lens.size(0)
query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \
if self.pcp_world_size > 1 and self.speculative_config else query_lens
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
num_prefills = num_reqs - num_decodes
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1:
decode_context_lens = input_batch.num_tokens[:num_decodes]
prefill_context_lens = input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_world_size,
self.dcp_world_size
],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape.
            # As with block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_cp_local_seq_lens(
torch.tensor(context_lens) - decode_idx,
self.pcp_world_size,
self.dcp_world_size,
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if num_decodes:
num_decodes_flatten = \
query_lens[:num_decodes].sum().item()
if query_lens[:num_decodes].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(num_decodes):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if num_prefills:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(num_decodes + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy())
if self.pcp_world_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
chunk_seqlens = []
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
q_req_offset = 0
kv_req_offset = 0
q_head_chunk_id = self.pcp_world_rank
q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
for i, seq_len in enumerate(query_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len *
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len *
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_world_size
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
tensor_npu = torch.zeros(len(lst),
dtype=dtype,
device=device)
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
non_blocking=True)
return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to(
torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor':
kv_with_q_head_nomask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
'kv_with_q_tail_nomask_idx_tensor':
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
}
for key, value in self.kv_idx_names.items():
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens],
dtype=torch.int32)
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
pcp_prefill_mask = attn_mask
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
'pcp_prefill_mask': pcp_prefill_mask
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
num_actual_tokens_pcp_padded]
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor']
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_mask_idx_tensor']
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor']
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_mask_idx_tensor']
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
'attn_mask_seqlens']
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
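    # Illustrative example of the head/tail indexing above: with
    # pcp_world_size = 2, pcp_rank = 0 and a single prefill whose local query
    # length is 4 (chunk_len = 2), q_head_chunk_id = 0 and q_tail_chunk_id = 3,
    # so (per-request offsets start at 0):
    #   q_head_idx = [0, 1], q_tail_idx = [2, 3]
    #   kv_with_q_head_nomask_idx = []                  (no earlier kv chunks)
    #   kv_with_q_head_mask_idx   = [0, 1]              (kv chunk aligned with the head query chunk)
    #   kv_with_q_tail_nomask_idx = [0, 1, 2, 3, 4, 5]  (kv chunks 0-2, fully attendable)
    #   kv_with_q_tail_mask_idx   = [6, 7]              (kv chunk aligned with the tail query chunk)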