[Feature] Refactor PCP &DCP related code (#5214)
### What this PR does / why we need it?
Refactor pcp& dcp related code. we use pcp_manager class to Unifiy
Manage pcp & dcp . as we do this , many code can be deleted from
model_runner, and can avoid break pcp & dcp by other developments.
RFC:https://github.com/vllm-project/vllm-ascend/issues/5449
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
@@ -225,21 +225,22 @@ class TestMtpProposer:
|
||||
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
|
||||
mock_deps.runner.pcp_size = 2
|
||||
mock_deps.runner.dcp_size = 1
|
||||
mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer(
|
||||
mock_deps.runner.pcp_manager = MagicMock()
|
||||
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
|
||||
32,
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
device='cpu',
|
||||
)
|
||||
mock_deps.runner.input_ids_pcp_full.cpu = \
|
||||
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
|
||||
torch.arange(32, dtype=torch.int32)
|
||||
mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer(
|
||||
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
|
||||
5,
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
device='cpu',
|
||||
)
|
||||
mock_deps.runner.query_start_loc_pcp_full.cpu = \
|
||||
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
|
||||
torch.tensor([0, 8, 16, 24, 32])
|
||||
mock_deps.positions = torch.arange(16, dtype=torch.int32)
|
||||
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
|
||||
|
||||
@@ -1,473 +0,0 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
|
||||
[
|
||||
(1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
|
||||
(1, 2, 3, [20, 30, 40], 1, False, 50, True),
|
||||
(2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
|
||||
(2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
|
||||
(2, 1, 3, [5, 10, 15], 3, False, 50, True),
|
||||
(2, 1, 3, [40, 50, 60], 0, False, 150, True),
|
||||
])
|
||||
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
num_decodes, use_mla, total_tokens,
|
||||
expect_not_none):
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
mock_runner.pcp_size = pcp_size
|
||||
mock_runner.dcp_size = dcp_size
|
||||
mock_runner.decode_threshold = 4
|
||||
mock_runner.pcp_rank = 0
|
||||
mock_runner.device = torch.device('cpu')
|
||||
mock_runner.dtype = torch.float32
|
||||
|
||||
mock_runner.parallel_config = MagicMock()
|
||||
mock_runner.parallel_config.cp_kv_cache_interleave_size = 64
|
||||
|
||||
mock_runner.vllm_config = MagicMock()
|
||||
mock_runner.vllm_config.model_config = MagicMock()
|
||||
mock_runner.vllm_config.model_config.use_mla = use_mla
|
||||
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_reqs = num_reqs
|
||||
mock_runner.speculative_config = None
|
||||
|
||||
num_computed_tokens = []
|
||||
num_prompt_tokens = []
|
||||
num_tokens = []
|
||||
|
||||
for i in range(num_reqs):
|
||||
if i < num_decodes:
|
||||
num_computed_tokens.append(query_lens[i])
|
||||
num_prompt_tokens.append(query_lens[i] // 2)
|
||||
num_tokens.append(query_lens[i])
|
||||
else:
|
||||
num_computed_tokens.append(0)
|
||||
num_prompt_tokens.append(query_lens[i])
|
||||
num_tokens.append(query_lens[i])
|
||||
|
||||
mock_runner.input_batch.num_computed_tokens_cpu = torch.tensor(
|
||||
num_computed_tokens)
|
||||
mock_runner.input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
|
||||
mock_runner.input_batch.num_tokens = torch.tensor(num_tokens)
|
||||
|
||||
mock_runner.query_lens = torch.tensor(query_lens)
|
||||
|
||||
mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
|
||||
mock_runner.pcp_allgather_restore_idx = torch.arange(total_tokens * 2)
|
||||
mock_runner.cp_kv_recover_idx_for_chunk = torch.arange(total_tokens)
|
||||
|
||||
mock_runner.long_seq_metadata = None
|
||||
mock_runner.num_actual_tokens_pcp_padded = 0
|
||||
mock_runner.kv_idx_names = {}
|
||||
mock_runner.extra_long_seq_kwargs = {}
|
||||
mock_runner.attn_mask = None
|
||||
mock_runner.q_head_idx_tensor = None
|
||||
mock_runner.q_tail_idx_tensor = None
|
||||
mock_runner.q_full_idx = None
|
||||
|
||||
method = NPUModelRunner._generate_pcp_metadata.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
result = method(total_tokens)
|
||||
|
||||
if not expect_not_none:
|
||||
assert result is None, f"Expected to return None, but got {type(result)}"
|
||||
else:
|
||||
assert result is not None, "Expected to return a metadata object, but got None."
|
||||
|
||||
assert hasattr(result, 'num_actual_tokens_pcp_padded')
|
||||
assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')
|
||||
|
||||
if pcp_size > 1:
|
||||
assert hasattr(result, 'pcp_allgather_restore_idx')
|
||||
|
||||
has_prefill_requests = (num_reqs - num_decodes) > 0
|
||||
if has_prefill_requests:
|
||||
assert hasattr(result, 'q_head_idx_tensor')
|
||||
assert hasattr(result, 'q_tail_idx_tensor')
|
||||
assert hasattr(result, 'q_full_idx')
|
||||
assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
|
||||
assert hasattr(result, 'attn_mask_seqlens')
|
||||
assert hasattr(result, 'head_attn_nomask_seqlens')
|
||||
assert hasattr(result, 'tail_attn_nomask_seqlens')
|
||||
|
||||
if hasattr(result, 'pcp_prefill_mask'
|
||||
) and result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048, 2048)
|
||||
else:
|
||||
if hasattr(result, 'pcp_prefill_mask'):
|
||||
if result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048,
|
||||
2048)
|
||||
|
||||
|
||||
def test_generate_pcp_metadata_edge_cases():
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.pcp_size = 2
|
||||
mock_runner.dcp_size = 1
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_reqs = 0
|
||||
mock_runner.query_lens = torch.tensor([10, 20, 30])
|
||||
|
||||
assert (mock_runner.input_batch.num_reqs
|
||||
or mock_runner.query_lens.size(0)) == 3
|
||||
|
||||
mock_runner.input_batch.num_reqs = 100
|
||||
mock_runner.query_lens = torch.ones(100) * 1000
|
||||
|
||||
for rank in [0, 1]:
|
||||
mock_runner.pcp_rank = rank
|
||||
q_head_chunk_id = rank
|
||||
q_tail_chunk_id = 2 * 2 - 1 - rank
|
||||
assert q_head_chunk_id == rank
|
||||
assert q_tail_chunk_id == 3 - rank
|
||||
|
||||
|
||||
def test_pcp_allgather_restore_idx_slicing():
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.pcp_size = 2
|
||||
mock_runner.pcp_allgather_restore_idx = torch.arange(1000)
|
||||
|
||||
total_num_scheduled_tokens = 200
|
||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 2
|
||||
|
||||
expected_slice = mock_runner.pcp_allgather_restore_idx[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
assert len(expected_slice) == 400
|
||||
assert expected_slice[0] == 0
|
||||
assert expected_slice[-1] == 399
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \
|
||||
"pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens",
|
||||
[
|
||||
# Case 1: prefill only
|
||||
([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]),
|
||||
|
||||
# Case 2: mix prefill and decode (with spec decode)
|
||||
([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]),
|
||||
|
||||
# Case 3: request which need to be padded
|
||||
([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]),
|
||||
|
||||
# Case 4: single request
|
||||
([10], 1, [0], [10], 4, 0, 1, [4]),
|
||||
])
|
||||
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
|
||||
num_prompt_tokens, pcp_size, pcp_rank,
|
||||
decode_threshold, expected_pcp_tokens):
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
mock_runner.pcp_size = pcp_size
|
||||
mock_runner.pcp_rank = pcp_rank
|
||||
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_reqs = num_reqs
|
||||
mock_runner.input_batch.num_computed_tokens_cpu = np.array(
|
||||
num_computed_tokens, dtype=np.int32)
|
||||
mock_runner.input_batch.num_prompt_tokens = np.array(num_prompt_tokens,
|
||||
dtype=np.int32)
|
||||
|
||||
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
|
||||
|
||||
mock_runner.num_pcp_pads = [0] * num_reqs
|
||||
mock_runner.arange_np = np.arange(10000)
|
||||
mock_runner.decode_threshold = decode_threshold
|
||||
|
||||
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
|
||||
pcp_tokens_result, positions_result, unpad_mask_result = mock_runner._update_tokens_for_pcp(
|
||||
tokens)
|
||||
|
||||
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
|
||||
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
|
||||
|
||||
total_pcp_tokens: int = np.sum(pcp_tokens_result)
|
||||
assert positions_result.shape == (total_pcp_tokens,), \
|
||||
f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
|
||||
|
||||
padded_tokens = [
|
||||
(t + 2 * pcp_size - 1) // (2 * pcp_size) *
|
||||
(2 * pcp_size) if num_computed_tokens[i] == 0 else t * pcp_size
|
||||
for i, t in enumerate(tokens)
|
||||
]
|
||||
total_padded_tokens: int = np.sum(padded_tokens)
|
||||
assert unpad_mask_result.shape[0] == total_padded_tokens, \
|
||||
f"unpad_mask size mismatch: expected {total_padded_tokens}, got {unpad_mask_result.shape[0]}"
|
||||
|
||||
|
||||
def test_update_tokens_for_pcp_with_padding():
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
mock_runner.pcp_size = 4
|
||||
mock_runner.pcp_rank = 0
|
||||
|
||||
mock_runner.arange_np = np.arange(10000)
|
||||
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_reqs = 3
|
||||
mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0, 0],
|
||||
dtype=np.int32)
|
||||
mock_runner.input_batch.num_prompt_tokens = np.array([5, 9, 13],
|
||||
dtype=np.int32)
|
||||
|
||||
mock_runner.num_pcp_pads = [0, 0, 0]
|
||||
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
|
||||
mock_runner.decode_threshold = 1
|
||||
|
||||
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
|
||||
tokens = [5, 9, 13]
|
||||
|
||||
pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
|
||||
tokens)
|
||||
|
||||
expected_pcp_tokens = [2, 4, 4]
|
||||
assert np.array_equal(pcp_tokens, expected_pcp_tokens), \
|
||||
f"Expected {expected_pcp_tokens}, got {pcp_tokens}"
|
||||
|
||||
expected_pads = [3, 7, 3]
|
||||
assert np.array_equal(mock_runner.num_pcp_pads, expected_pads), \
|
||||
f"Expected padding {expected_pads}, got {mock_runner.num_pcp_pads}"
|
||||
|
||||
|
||||
def test_update_tokens_for_pcp_unpad_mask():
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
mock_runner.pcp_size = 4
|
||||
mock_runner.pcp_rank = 0
|
||||
|
||||
mock_runner.arange_np = np.arange(10000)
|
||||
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_reqs = 2
|
||||
mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0],
|
||||
dtype=np.int32)
|
||||
mock_runner.input_batch.num_prompt_tokens = np.array([5, 7],
|
||||
dtype=np.int32)
|
||||
|
||||
mock_runner.num_pcp_pads = [0, 0]
|
||||
mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long)
|
||||
mock_runner.decode_threshold = 1
|
||||
|
||||
mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__(
|
||||
mock_runner, NPUModelRunner)
|
||||
|
||||
tokens = [5, 7]
|
||||
|
||||
pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp(
|
||||
tokens)
|
||||
|
||||
assert unpad_mask.dtype == torch.bool, \
|
||||
f"unpad_mask should be bool, got {unpad_mask.dtype}"
|
||||
|
||||
padded_tokens = [8, 8]
|
||||
expected_length = sum(padded_tokens)
|
||||
assert unpad_mask.shape[0] == expected_length, \
|
||||
f"unpad_mask length mismatch: expected {expected_length}, got {unpad_mask.shape[0]}"
|
||||
|
||||
expected_mask = [True] * 5 + [False] * 3 + [True] * 7 + [False] * 1
|
||||
actual_mask = unpad_mask.numpy().tolist()
|
||||
assert actual_mask == expected_mask, \
|
||||
f"unpad_mask incorrect. Expected {expected_mask}, got {actual_mask}"
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
|
||||
[
|
||||
# without pcp and dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 1, 1, 1,
|
||||
torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
|
||||
# pcp
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 1,
|
||||
torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
|
||||
# dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 1, 2, 1,
|
||||
torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
|
||||
# pcp + dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 2, 1,
|
||||
torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
|
||||
[[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
|
||||
# specify interleave_size
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 2,
|
||||
torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 128,
|
||||
torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
|
||||
(torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
|
||||
torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
|
||||
[[128, 0], [0, 0]], [[128, 1], [0, 0]],
|
||||
[[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_get_cp_local_seq_lens(
|
||||
seq_lens,
|
||||
pcp_world_size,
|
||||
dcp_world_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
target,
|
||||
):
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
ret = NPUModelRunner._get_cp_local_seq_lens(mock_runner, seq_lens,
|
||||
pcp_world_size, dcp_world_size,
|
||||
cp_kv_cache_interleave_size)
|
||||
assert torch.equal(ret, target)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pcp_mtp_mock_runner():
|
||||
# set up pcp & mtp related buffers
|
||||
max_num_reqs = 4
|
||||
max_model_len = 4096
|
||||
max_num_tokens = 4096
|
||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
||||
mock_runner.device = 'cpu'
|
||||
mock_runner.pin_memory = False
|
||||
|
||||
# Init model_runner pcp_mtp related buffers
|
||||
mock_runner.query_start_loc_pcp_full = NPUModelRunner._make_buffer(
|
||||
mock_runner, max_num_reqs + 1, dtype=torch.int32)
|
||||
|
||||
positions_buff = torch.zeros(max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
mock_runner.positions_pcp_full = positions_buff
|
||||
mock_runner.positions_pcp_full_np = positions_buff.numpy()
|
||||
|
||||
mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer(
|
||||
mock_runner, max_num_tokens, dtype=torch.int32)
|
||||
mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer(
|
||||
mock_runner, max_num_reqs, dtype=torch.int32)
|
||||
mock_runner.decode_threshold = 1
|
||||
|
||||
mock_runner.arange_np = np.arange(max_model_len)
|
||||
mock_runner.input_batch = MagicMock()
|
||||
mock_runner.input_batch.num_computed_tokens_cpu = \
|
||||
np.zeros(max_num_reqs, dtype=np.int32)
|
||||
token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
mock_runner.input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
|
||||
mock_runner.input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()
|
||||
return mock_runner
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"req_ids, num_computed_tokens," \
|
||||
"token_ids_tensor_list," \
|
||||
"num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
|
||||
"target_input_ids_pcp_full, target_query_start_loc_pcp_full",
|
||||
[
|
||||
# prefill
|
||||
(
|
||||
['0'], np.array([0]),
|
||||
[torch.tensor([0, 671, 6102, 294, 8760, 344])],
|
||||
1, 6, {'0': 6},
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344]),
|
||||
torch.tensor([0, 6])
|
||||
),
|
||||
# decode
|
||||
(
|
||||
['0'], np.array([6]),
|
||||
[torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
|
||||
1, 2, {'0': 2},
|
||||
torch.tensor([88907, 0]),
|
||||
torch.tensor([0, 2])
|
||||
),
|
||||
# decode + prefill
|
||||
(
|
||||
['0', '1'], np.array([6, 0]),
|
||||
[
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
|
||||
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
|
||||
],
|
||||
2, 12, {'0': 2, '1': 10},
|
||||
torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
|
||||
torch.tensor([0, 2, 12])
|
||||
),
|
||||
# decodes + prefills
|
||||
(
|
||||
['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
|
||||
[
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
|
||||
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
|
||||
torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
|
||||
torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
|
||||
],
|
||||
4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
|
||||
torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
|
||||
0, 671, 5335, 1469, 7539, 305, 6397]),
|
||||
torch.tensor([0, 2, 4, 12, 19])
|
||||
),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_generate_pcp_mtp_input(
|
||||
pcp_mtp_mock_runner,
|
||||
req_ids,
|
||||
num_computed_tokens,
|
||||
token_ids_tensor_list,
|
||||
num_reqs,
|
||||
total_num_scheduled_tokens,
|
||||
num_scheduled_tokens,
|
||||
target_input_ids_pcp_full,
|
||||
target_query_start_loc_pcp_full,
|
||||
):
|
||||
mock_runner = pcp_mtp_mock_runner
|
||||
token_ids_cpu_tensor = mock_runner.input_batch.token_ids_cpu_tensor
|
||||
|
||||
# Set input_batch
|
||||
mock_runner.input_batch.req_ids = req_ids
|
||||
mock_runner.input_batch.num_computed_tokens_cpu[:num_computed_tokens.
|
||||
size] = num_computed_tokens
|
||||
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
|
||||
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
|
||||
|
||||
NPUModelRunner._generate_pcp_mtp_input(mock_runner, num_reqs,
|
||||
total_num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
assert torch.equal(
|
||||
mock_runner.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
|
||||
target_input_ids_pcp_full)
|
||||
assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1],
|
||||
target_query_start_loc_pcp_full)
|
||||
322
tests/ut/worker/test_pcp_manager.py
Normal file
322
tests/ut/worker/test_pcp_manager.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_ascend.worker.pcp_utils import PCPManager
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
|
||||
[
|
||||
(1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
|
||||
(1, 2, 3, [20, 30, 40], 1, False, 50, True),
|
||||
(2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
|
||||
(2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
|
||||
(2, 1, 3, [5, 10, 15], 3, False, 50, True),
|
||||
(2, 1, 3, [40, 50, 60], 0, False, 150, True),
|
||||
])
|
||||
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
num_decodes, use_mla, total_tokens,
|
||||
expect_not_none):
|
||||
vllm_config = MagicMock()
|
||||
vllm_config.model_config = MagicMock()
|
||||
vllm_config.model_config.use_mla = use_mla
|
||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
|
||||
vllm_config.speculative_config.num_speculative_tokens = 0
|
||||
|
||||
pcp_manager = PCPManager(pcp_world_size=pcp_size,
|
||||
pcp_rank=0,
|
||||
dcp_world_size=dcp_size,
|
||||
dcp_rank=0,
|
||||
max_buffer_num_tokens=10000,
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
pin_memory=False)
|
||||
input_batch = MagicMock()
|
||||
input_batch.num_reqs = num_reqs
|
||||
|
||||
num_computed_tokens = []
|
||||
num_prompt_tokens = []
|
||||
num_tokens = []
|
||||
|
||||
for i in range(num_reqs):
|
||||
if i < num_decodes:
|
||||
num_computed_tokens.append(query_lens[i])
|
||||
num_prompt_tokens.append(query_lens[i] // 2)
|
||||
num_tokens.append(query_lens[i])
|
||||
else:
|
||||
num_computed_tokens.append(0)
|
||||
num_prompt_tokens.append(query_lens[i])
|
||||
num_tokens.append(query_lens[i])
|
||||
|
||||
input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
|
||||
input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
|
||||
input_batch.num_tokens = torch.tensor(num_tokens)
|
||||
|
||||
query_lens = torch.tensor(query_lens)
|
||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
|
||||
input_batch)
|
||||
|
||||
if not expect_not_none:
|
||||
assert result is None, f"Expected to return None, but got {type(result)}"
|
||||
else:
|
||||
assert result is not None, "Expected to return a metadata object, but got None."
|
||||
|
||||
assert hasattr(result, 'num_actual_tokens_pcp_padded')
|
||||
assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')
|
||||
|
||||
if pcp_size > 1:
|
||||
assert hasattr(result, 'pcp_allgather_restore_idx')
|
||||
|
||||
has_prefill_requests = (num_reqs - num_decodes) > 0
|
||||
if has_prefill_requests:
|
||||
assert hasattr(result, 'q_head_idx_tensor')
|
||||
assert hasattr(result, 'q_tail_idx_tensor')
|
||||
assert hasattr(result, 'q_full_idx')
|
||||
assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
|
||||
assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
|
||||
assert hasattr(result, 'attn_mask_seqlens')
|
||||
assert hasattr(result, 'head_attn_nomask_seqlens')
|
||||
assert hasattr(result, 'tail_attn_nomask_seqlens')
|
||||
|
||||
if hasattr(result, 'pcp_prefill_mask'
|
||||
) and result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048, 2048)
|
||||
else:
|
||||
if hasattr(result, 'pcp_prefill_mask'):
|
||||
if result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048,
|
||||
2048)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
|
||||
[
|
||||
# Case 1: prefill only
|
||||
([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
|
||||
|
||||
# # Case 2: mix prefill and decode
|
||||
([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]),
|
||||
|
||||
# # Case 3: request which need to be padded
|
||||
([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
|
||||
|
||||
# Case 4: single request
|
||||
([10], 1, [0], [10], 4, 0, [4]),
|
||||
])
|
||||
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
|
||||
num_prompt_tokens, pcp_size, pcp_rank,
|
||||
expected_pcp_tokens):
|
||||
vllm_config = MagicMock()
|
||||
vllm_config.model_config = MagicMock()
|
||||
vllm_config.speculative_config.num_speculative_tokens = 0
|
||||
|
||||
pcp_manager = PCPManager(pcp_world_size=pcp_size,
|
||||
pcp_rank=0,
|
||||
dcp_world_size=1,
|
||||
dcp_rank=0,
|
||||
max_buffer_num_tokens=10000,
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
pin_memory=False)
|
||||
input_batch = MagicMock()
|
||||
input_batch.num_reqs = num_reqs
|
||||
input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens,
|
||||
dtype=np.int32)
|
||||
input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
|
||||
arange_np = np.arange(10000)
|
||||
pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
|
||||
np.array(tokens), arange_np, num_reqs, 1)
|
||||
|
||||
assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
|
||||
f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
|
||||
|
||||
total_pcp_tokens: int = np.sum(pcp_tokens_result)
|
||||
assert positions_result.shape == (total_pcp_tokens,), \
|
||||
f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
|
||||
[
|
||||
# without pcp and dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 1, 1, 1,
|
||||
torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
|
||||
# pcp
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 1,
|
||||
torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
|
||||
# dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 1, 2, 1,
|
||||
torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
|
||||
# pcp + dcp
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 2, 1,
|
||||
torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
|
||||
[[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
|
||||
# specify interleave_size
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 2,
|
||||
torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
|
||||
(torch.tensor([1, 2, 128, 129]), 2, 1, 128,
|
||||
torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
|
||||
(torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
|
||||
torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
|
||||
[[128, 0], [0, 0]], [[128, 1], [0, 0]],
|
||||
[[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_get_cp_local_seq_lens(
|
||||
seq_lens,
|
||||
pcp_world_size,
|
||||
dcp_world_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
target,
|
||||
):
|
||||
vllm_config = MagicMock()
|
||||
vllm_config.model_config = MagicMock()
|
||||
vllm_config.speculative_config.num_speculative_tokens = 0
|
||||
pcp_manager = PCPManager(pcp_world_size=pcp_world_size,
|
||||
pcp_rank=0,
|
||||
dcp_world_size=dcp_world_size,
|
||||
dcp_rank=0,
|
||||
max_buffer_num_tokens=10000,
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
pin_memory=False)
|
||||
ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
|
||||
dcp_world_size,
|
||||
cp_kv_cache_interleave_size)
|
||||
assert torch.equal(ret, target)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"req_ids, num_computed_tokens," \
|
||||
"token_ids_tensor_list," \
|
||||
"num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
|
||||
"target_input_ids_pcp_full, target_query_start_loc_pcp_full",
|
||||
[
|
||||
# prefill
|
||||
(
|
||||
['0'], np.array([0]),
|
||||
[torch.tensor([0, 671, 6102, 294, 8760, 344])],
|
||||
1, 6, {'0': 6},
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344]),
|
||||
torch.tensor([0, 6])
|
||||
),
|
||||
# decode
|
||||
(
|
||||
['0'], np.array([6]),
|
||||
[torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
|
||||
1, 2, {'0': 2},
|
||||
torch.tensor([88907, 0]),
|
||||
torch.tensor([0, 2])
|
||||
),
|
||||
# decode + prefill
|
||||
(
|
||||
['0', '1'], np.array([6, 0]),
|
||||
[
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
|
||||
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
|
||||
],
|
||||
2, 12, {'0': 2, '1': 10},
|
||||
torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
|
||||
torch.tensor([0, 2, 12])
|
||||
),
|
||||
# decodes + prefills
|
||||
(
|
||||
['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
|
||||
[
|
||||
torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
|
||||
torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
|
||||
torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
|
||||
torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
|
||||
],
|
||||
4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
|
||||
torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
|
||||
0, 671, 5335, 1469, 7539, 305, 6397]),
|
||||
torch.tensor([0, 2, 4, 12, 19])
|
||||
),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_generate_pcp_mtp_input(
|
||||
req_ids,
|
||||
num_computed_tokens,
|
||||
token_ids_tensor_list,
|
||||
num_reqs,
|
||||
total_num_scheduled_tokens,
|
||||
num_scheduled_tokens,
|
||||
target_input_ids_pcp_full,
|
||||
target_query_start_loc_pcp_full,
|
||||
):
|
||||
max_num_reqs = 4
|
||||
max_model_len = 4096
|
||||
max_num_tokens = 4096
|
||||
vllm_config = MagicMock()
|
||||
vllm_config.model_config = MagicMock()
|
||||
vllm_config.speculative_config.num_speculative_tokens = 1
|
||||
vllm_config.scheduler_config.max_num_seqs = max_num_reqs
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max_model_len
|
||||
pcp_manager = PCPManager(pcp_world_size=2,
|
||||
pcp_rank=0,
|
||||
dcp_world_size=1,
|
||||
dcp_rank=0,
|
||||
max_buffer_num_tokens=max_num_tokens,
|
||||
max_num_reqs=max_num_reqs,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
pin_memory=False)
|
||||
arange_np = np.arange(max_model_len)
|
||||
input_batch = MagicMock()
|
||||
input_batch.num_computed_tokens_cpu = \
|
||||
np.zeros(max_num_reqs, dtype=np.int32)
|
||||
token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
|
||||
input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()
|
||||
token_ids_cpu_tensor = input_batch.token_ids_cpu_tensor
|
||||
|
||||
# Set input_batch
|
||||
input_batch.req_ids = req_ids
|
||||
input_batch.num_computed_tokens_cpu[:num_computed_tokens.
|
||||
size] = num_computed_tokens
|
||||
for i, token_ids_tensor in enumerate(token_ids_tensor_list):
|
||||
token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
|
||||
|
||||
pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
|
||||
num_scheduled_tokens, False,
|
||||
input_batch, arange_np)
|
||||
assert torch.equal(
|
||||
pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
|
||||
target_input_ids_pcp_full)
|
||||
assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
|
||||
target_query_start_loc_pcp_full)
|
||||
@@ -279,9 +279,9 @@ class EagleProposer(VllmEagleProposer):
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
|
||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
|
||||
@@ -179,7 +179,7 @@ class MtpProposer(EagleProposer):
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update long_seq related params and flatten block_table
|
||||
common_attn_metadata.prefill_context_parallel_metadata = \
|
||||
self.runner.long_seq_metadata
|
||||
self.runner.pcp_manager.long_seq_metadata
|
||||
common_attn_metadata.block_table_tensor = \
|
||||
self.runner.input_batch.block_table[0].get_device_tensor()[
|
||||
:num_reqs * self.decode_threshold]
|
||||
@@ -286,9 +286,9 @@ class MtpProposer(EagleProposer):
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
|
||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
@@ -303,7 +303,7 @@ class MtpProposer(EagleProposer):
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
token_indices_to_sample = \
|
||||
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
|
||||
query_start_loc_pcp_full[1:num_reqs + 1] - 1
|
||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
@@ -751,8 +751,8 @@ class MtpProposer(EagleProposer):
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0, self.runner.
|
||||
pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||
hidden_states, 0, self.runner.pcp_manager.
|
||||
pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]])
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
@@ -797,13 +797,13 @@ class MtpProposer(EagleProposer):
|
||||
# (_generate_pcp_mtp_input), and use updated slot_indices
|
||||
# to get corresponding slot_mapping in each step.
|
||||
num_reject_tokens = torch.tensor(
|
||||
self.runner.cu_num_tokens_pcp_full,
|
||||
self.runner.pcp_manager.cu_num_tokens_pcp_full,
|
||||
dtype=torch.int32).to(
|
||||
self.device) - ori_last_token_indices - 1
|
||||
num_accept_tokens = \
|
||||
query_lens_d.to(self.device) - num_reject_tokens
|
||||
ori_seq_len = attn_metadata_i.seq_lens
|
||||
mtp_slot_mapping = self.runner.mtp_slot_pad
|
||||
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
|
||||
|
||||
# slot_mapping index base offset:
|
||||
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
|
||||
@@ -889,7 +889,7 @@ class MtpProposer(EagleProposer):
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update local seq_len and batch_seq_mask
|
||||
num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
|
||||
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
|
||||
ori_seq_len + step + 1,
|
||||
self.pcp_size,
|
||||
self.dcp_size,
|
||||
|
||||
@@ -24,7 +24,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Manager
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -78,8 +78,7 @@ from vllm.v1.worker.utils import AttentionGroup
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendPrefillContextParallelMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
@@ -109,6 +108,7 @@ from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
|
||||
lmhead_tp_enable, maybe_trans_nz,
|
||||
set_weight_prefetch_method)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
from vllm_ascend.worker.pcp_utils import PCPManager
|
||||
|
||||
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
||||
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
|
||||
@@ -202,6 +202,26 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.pcp_rank = 0
|
||||
if self.pcp_size > 1:
|
||||
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
|
||||
max_buffer_num_tokens = self.max_num_tokens
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
max_buffer_num_tokens = (self.max_num_tokens +
|
||||
self.max_num_reqs * 2 * self.pcp_size)
|
||||
self.pcp_manager = PCPManager(
|
||||
self.pcp_size,
|
||||
self.pcp_rank,
|
||||
self.dcp_size,
|
||||
self.dcp_rank,
|
||||
max_buffer_num_tokens,
|
||||
self.max_num_reqs,
|
||||
self.device,
|
||||
self.vllm_config,
|
||||
self.pin_memory,
|
||||
)
|
||||
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
|
||||
self.input_ids = self._make_buffer(max_buffer_num_tokens,
|
||||
dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_buffer_num_tokens,
|
||||
dtype=torch.int64)
|
||||
self.sampler = AscendSampler()
|
||||
self.attn_mask = None
|
||||
self.attn_state = None
|
||||
@@ -262,32 +282,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
|
||||
self.uniform_decode_query_len)
|
||||
set_mc2_mask(vllm_config, self.device)
|
||||
self.pcp_allgather_restore_idx = torch.zeros(
|
||||
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
|
||||
[] for _ in range(self.pcp_size)
|
||||
]
|
||||
|
||||
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
|
||||
self.pcp_padded_slot_mapping = torch.zeros(
|
||||
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_actual_tokens_pcp_padded = 0
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
|
||||
dtype=torch.int32)
|
||||
self.query_start_loc_pcp_full = self._make_buffer(
|
||||
self.max_num_reqs + 1, dtype=torch.int32)
|
||||
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
|
||||
self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.decode_threshold = 1 + (
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0)
|
||||
@@ -359,6 +353,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# None in the first PP rank. The rest are set after load_model.
|
||||
self.intermediate_tensors: IntermediateTensors | None = None
|
||||
self.reorder_batch_threshold: int | None = None
|
||||
self.long_seq_metadata = None
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
self.num_sms = None
|
||||
@@ -508,49 +503,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
return self.attn_mask_builder.get_mla_mask(self.dtype)
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
def generate_kv_idx(self, scheduler_output):
|
||||
if not self.pcp_size > 1:
|
||||
return
|
||||
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
|
||||
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
is_prefill = self.input_batch.num_computed_tokens_cpu[
|
||||
i] < self.input_batch.num_prompt_tokens[i]
|
||||
if is_prefill:
|
||||
num_cp_padded_scheduled_tokens = cdiv(
|
||||
num_scheduled_tokens,
|
||||
2 * self.pcp_size) * (2 * self.pcp_size)
|
||||
full_indices = list(
|
||||
range(self.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||
self.pcp_size * self.dcp_size * self.max_num_reqs))
|
||||
chunk_size = num_cp_padded_scheduled_tokens // (2 *
|
||||
self.pcp_size)
|
||||
num_added_recover_tokens = len(
|
||||
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size
|
||||
for rank in range(self.pcp_size):
|
||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||
full_indices[rank * chunk_size +
|
||||
num_added_recover_tokens:(rank + 1) *
|
||||
chunk_size + num_added_recover_tokens])
|
||||
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||
full_indices[num_cp_padded_scheduled_tokens -
|
||||
(rank + 1) * chunk_size +
|
||||
num_added_recover_tokens:
|
||||
num_cp_padded_scheduled_tokens -
|
||||
rank * chunk_size +
|
||||
num_added_recover_tokens])
|
||||
|
||||
cp_kv_recover_idx_for_chunk = torch.from_numpy(
|
||||
np.concatenate(
|
||||
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
|
||||
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
|
||||
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
|
||||
non_blocking=True)
|
||||
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
|
||||
torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -574,43 +526,70 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
|
||||
positions_np = np.add(
|
||||
self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
)
|
||||
|
||||
self.input_batch.block_table.compute_slot_mapping(
|
||||
req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
total_num_pcp_pads = 0
|
||||
if self.pcp_size > 1:
|
||||
if not self.vllm_config.model_config.use_mla:
|
||||
self.generate_kv_idx(scheduler_output)
|
||||
tokens_before_update = tokens.copy()
|
||||
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
|
||||
tokens)
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
|
||||
total_num_pcp_pads = torch.sum(self.num_pcp_pads[:num_reqs]).item()
|
||||
else:
|
||||
position_pcp, pcp_unpad_mask = None, None
|
||||
self.num_pcp_pads[:num_reqs] = 0
|
||||
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
if not scheduler_output.scheduled_spec_decode_tokens:
|
||||
num_valid_tokens = np.array(tokens, dtype=np.int32)
|
||||
else:
|
||||
num_valid_tokens = np.array([
|
||||
num_tokens -
|
||||
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
||||
for num_tokens, i in zip((tokens_before_update if self.
|
||||
pcp_size > 1 else tokens), req_ids)
|
||||
for num_tokens, i in zip(tokens, req_ids)
|
||||
],
|
||||
dtype=np.int32)
|
||||
# Get the attention state.
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
# Determine if it's a splitfuse batch
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
self.attn_mask = self._make_attention_mask(attn_state)
|
||||
|
||||
# Get positions.
|
||||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
arange,
|
||||
out=positions_np)
|
||||
|
||||
self.input_batch.block_table.compute_slot_mapping(
|
||||
req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(
|
||||
total_num_scheduled_tokens)
|
||||
# for pcp, prefill mtp should use origin scheduleroutput ,
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
self.pcp_manager.generate_pcp_mtp_input(
|
||||
num_reqs, total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens, with_prefill,
|
||||
self.input_batch, self.arange_np, req_indices, positions_np,
|
||||
cu_num_tokens)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
if not self.vllm_config.model_config.use_mla:
|
||||
self.pcp_manager.generate_kv_idx(scheduler_output,
|
||||
self.input_batch)
|
||||
num_scheduled_tokens[:
|
||||
num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
|
||||
num_scheduled_tokens[:num_reqs],
|
||||
self.arange_np,
|
||||
self.input_batch.num_reqs,
|
||||
self.reorder_batch_threshold,
|
||||
)
|
||||
# Re-update after PCP split sequences.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens)
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
cu_num_tokens, _ = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||||
np.add(
|
||||
self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
position_pcp[:total_num_scheduled_tokens],
|
||||
out=positions_np,
|
||||
)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
if (self.use_aclgraph and total_num_scheduled_tokens
|
||||
<= self.cudagraph_batch_sizes[-1]):
|
||||
# Add padding to the batch size.
|
||||
@@ -627,17 +606,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
else:
|
||||
# Eager mode.
|
||||
num_input_tokens = total_num_scheduled_tokens
|
||||
|
||||
# Get the attention state.
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
# Determine if it's a splitfuse batch
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
# Get info across DP ranks.
|
||||
@@ -646,7 +614,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
(maybe_padded_num_tokens, num_tokens_across_dp,
|
||||
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
|
||||
with_prefill)
|
||||
|
||||
self.with_prefill = with_prefill
|
||||
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
|
||||
# We should consider removing maybe_padded_num_tokens later
|
||||
num_input_tokens = maybe_padded_num_tokens
|
||||
@@ -655,24 +623,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
# Get request indices.
|
||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||
req_indices = np.repeat(self.arange_np[:num_reqs],
|
||||
num_scheduled_tokens)
|
||||
|
||||
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
||||
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
positions_np = self.positions.np[:total_num_scheduled_tokens]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||
position_pcp[:total_num_scheduled_tokens],
|
||||
out=positions_np)
|
||||
else:
|
||||
self.positions.np[:total_num_scheduled_tokens] = positions_np
|
||||
|
||||
# Calculate M-RoPE positions.
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self._calc_mrope_positions(scheduler_output)
|
||||
@@ -766,21 +718,11 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
self.seq_lens.gpu[num_reqs:].fill_(0)
|
||||
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
# Copy the tensors to the NPU.
|
||||
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
|
||||
cu_num_tokens)
|
||||
self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||
self.positions.copy_to_gpu()
|
||||
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
self.attn_mask = self._make_attention_mask(attn_state)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
self.with_prefill = with_prefill
|
||||
self.num_tokens_across_dp = num_tokens_across_dp
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
|
||||
# Record the index of requests that should not be sampled,
|
||||
@@ -914,9 +856,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# TODO: Support prompt logprobs.
|
||||
spec_decode_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
logits_indices = torch.from_numpy(
|
||||
cu_num_tokens
|
||||
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
|
||||
logits_indices = self.pcp_manager.get_logits_indices(
|
||||
cu_num_tokens, num_reqs)
|
||||
logits_indices = logits_indices.pin_memory().to(
|
||||
self.device, non_blocking=True)
|
||||
else:
|
||||
@@ -938,8 +879,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens,
|
||||
self.num_pcp_pads[:num_reqs].numpy())
|
||||
num_draft_tokens,
|
||||
cu_num_tokens,
|
||||
num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
|
||||
if self.pcp_size > 1 else None)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
# For DECODE only cuda graph of some attention backends (e.g., GDN).
|
||||
@@ -961,23 +904,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
self._generate_pcp_mtp_input(
|
||||
num_reqs, scheduler_output.total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens, with_prefill,
|
||||
req_indices, positions_np, cu_num_tokens)
|
||||
|
||||
long_seq_metadata = self._generate_pcp_metadata(
|
||||
total_num_scheduled_tokens)
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
# NOTE: This is strange, why did we use total_num_scheduled_tokens before?
|
||||
slot_mapping_size = (total_num_scheduled_tokens
|
||||
if self.pcp_size == 1 else
|
||||
total_num_scheduled_tokens * self.pcp_size -
|
||||
total_num_pcp_pads)
|
||||
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||||
EncoderOnlyAttentionSpec):
|
||||
# Encoder-only layers do not have KV cache, so we need to
|
||||
@@ -993,30 +923,30 @@ class NPUModelRunner(GPUModelRunner):
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
maybe_pcp_full_tokens = (
|
||||
num_input_tokens if self.pcp_size == 1 else
|
||||
total_num_scheduled_tokens * self.pcp_size -
|
||||
sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]))
|
||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
blk_table_tensor = blk_table.get_device_tensor()
|
||||
blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0)
|
||||
if self.pcp_size > 1:
|
||||
slot_mapping_for_pcp = blk_table.slot_mapping.gpu[:
|
||||
long_seq_metadata
|
||||
.
|
||||
num_actual_tokens_pcp_padded]
|
||||
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
|
||||
assert pcp_unpad_mask is not None
|
||||
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
|
||||
pcp_unpad_mask
|
||||
.
|
||||
shape[
|
||||
0]]
|
||||
pcp_padded_slot_mapping.fill_(-1)
|
||||
pcp_padded_slot_mapping[
|
||||
pcp_unpad_mask] = slot_mapping_for_pcp[:
|
||||
slot_mapping_size]
|
||||
slot_mapping_for_pcp[:long_seq_metadata.
|
||||
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
|
||||
blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
|
||||
slot_mapping_for_pcp
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:
|
||||
maybe_pcp_full_tokens]
|
||||
if self.pcp_size * self.dcp_size == 1:
|
||||
slot_mapping[
|
||||
total_num_scheduled_tokens:num_input_tokens].fill_(-1)
|
||||
slot_mapping = blk_table.slot_mapping.gpu
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
|
||||
total_num_scheduled_tokens, self.query_lens,
|
||||
self.attn_mask, self.input_batch)
|
||||
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
||||
slot_mapping = slot_mapping[:maybe_pcp_full_tokens]
|
||||
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
|
||||
total_num_scheduled_tokens,
|
||||
slot_mapping,
|
||||
)
|
||||
blk_table.slot_mapping.gpu[:self.pcp_manager.
|
||||
num_actual_tokens_pcp_padded] = slot_mapping
|
||||
|
||||
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
|
||||
# has been split to multiple parts, and there are 3 parts that is related to this
|
||||
@@ -1055,7 +985,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
|
||||
seq_lens=self.seq_lens.gpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=slot_mapping_size,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
num_input_tokens=num_input_tokens,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
# TODO: change this to the right block table for linear attn
|
||||
@@ -1069,8 +999,9 @@ class NPUModelRunner(GPUModelRunner):
|
||||
attn_state=self.attn_state,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
prefill_context_parallel_metadata=long_seq_metadata,
|
||||
max_seq_len=0)
|
||||
prefill_context_parallel_metadata=self.long_seq_metadata,
|
||||
max_seq_len=0,
|
||||
)
|
||||
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
# For pcp + spec decode, we flatten block_table
|
||||
@@ -1080,8 +1011,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||||
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
||||
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
||||
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
|
||||
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
|
||||
ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[:
|
||||
num_reqs]
|
||||
ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[:
|
||||
num_reqs]
|
||||
num_prefill_reqs = (ori_query_lens
|
||||
> self.decode_threshold).sum().item()
|
||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||
@@ -1097,13 +1030,17 @@ class NPUModelRunner(GPUModelRunner):
|
||||
ori_query_lens[:num_decode_reqs], dim=0))
|
||||
common_attn_metadata.block_table_tensor = \
|
||||
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
|
||||
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
|
||||
assert self.long_seq_metadata is not None
|
||||
self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
|
||||
|
||||
if 'pad_size' in locals() and pad_size > 0:
|
||||
ori_query_lens_cpu[-pad_size:] = \
|
||||
torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
|
||||
long_seq_metadata.max_query_len_pcp_full = \
|
||||
self.long_seq_metadata.max_query_len_pcp_full = \
|
||||
ori_query_lens_cpu.max().item()
|
||||
|
||||
|
||||
|
||||
if self.speculative_config and \
|
||||
self.spec_decode_common_attn_metadata is None:
|
||||
self.spec_decode_common_attn_metadata = common_attn_metadata
|
||||
@@ -1193,19 +1130,12 @@ class NPUModelRunner(GPUModelRunner):
|
||||
pad_size = get_forward_context().pad_size
|
||||
if pad_size > 0:
|
||||
hidden_states = hidden_states[:-pad_size, :]
|
||||
|
||||
if self.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states[:self.num_actual_tokens_pcp_padded //
|
||||
self.pcp_size], 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0,
|
||||
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||
return hidden_states
|
||||
return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
|
||||
hidden_states)
|
||||
|
||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens):
|
||||
if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
|
||||
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
@@ -1231,7 +1161,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self,
|
||||
num_draft_tokens: np.ndarray,
|
||||
cu_num_scheduled_tokens: np.ndarray,
|
||||
num_pcp_pads: np.ndarray,
|
||||
num_pcp_pads: np.ndarray | None,
|
||||
) -> SpecDecodeMetadata:
|
||||
# Inputs:
|
||||
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
|
||||
@@ -1846,7 +1776,9 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
|
||||
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
|
||||
num_tokens, self.query_lens, self.attn_mask,
|
||||
self.input_batch)
|
||||
if long_seq_metadata is not None:
|
||||
pcp_world_size = get_pcp_group().world_size
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
@@ -2890,365 +2822,6 @@ class NPUModelRunner(GPUModelRunner):
|
||||
parent_module_name):
|
||||
super().capture_model()
|
||||
|
||||
def _update_tokens_for_pcp(self, tokens):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
tokens = np.array(tokens, dtype=np.int32)
|
||||
num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
|
||||
num_decode_tokens = sum(tokens[:num_decode_reqs])
|
||||
num_padded_scheduled_tokens = np.ceil(
|
||||
tokens /
|
||||
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
|
||||
num_padded_scheduled_tokens[:num_decode_reqs] = (
|
||||
tokens[:num_decode_reqs] * self.pcp_size)
|
||||
self.num_pcp_pads[:num_reqs] = torch.tensor(
|
||||
num_padded_scheduled_tokens - tokens)
|
||||
cu_padded_tokens, pcp_padded_arange = \
|
||||
self._get_cumsum_and_arange(num_padded_scheduled_tokens)
|
||||
unpad_mask = torch.from_numpy(
|
||||
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
|
||||
unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
|
||||
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
|
||||
unpad_mask_decode[:, 0] = True
|
||||
unpad_mask_decode[:, 1:] = False
|
||||
|
||||
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
|
||||
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
|
||||
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
|
||||
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
|
||||
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
|
||||
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
|
||||
pcp_tokens)
|
||||
|
||||
def get_current_rank_positions(cu_tokens, rank):
|
||||
positions_start_loc = np.zeros_like(cu_tokens)
|
||||
positions_start_loc[1:] = cu_tokens[:-1]
|
||||
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
|
||||
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
|
||||
tail_start_loc = positions_start_loc + \
|
||||
(2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
|
||||
positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
|
||||
np.repeat(head_start_loc, pcp_chunk_sizes)
|
||||
# Decode reqs do not have tail chunks.
|
||||
positions[~pcp_head_chunk_mask] = \
|
||||
pcp_chunk_arange[num_decode_tokens:] + \
|
||||
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
|
||||
return positions
|
||||
|
||||
positions = get_current_rank_positions(
|
||||
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
|
||||
# Decode tokens are duplicate and their positions always be 0.
|
||||
if num_decode_reqs > 0:
|
||||
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
|
||||
tokens[:num_decode_reqs])[1]
|
||||
|
||||
all_positions = [
|
||||
get_current_rank_positions(cu_padded_tokens, rank_i)
|
||||
for rank_i in range(self.pcp_size)
|
||||
]
|
||||
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
|
||||
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
|
||||
all_positions_tensor.float().argsort().long(), non_blocking=True)
|
||||
return pcp_tokens, positions, unpad_mask
|
||||
|
||||
def _get_cp_local_seq_lens(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
pcp_world_size: int = 1,
|
||||
dcp_world_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
|
||||
use this function to calculate split decode seq_lens of each (p/d)cp rank.
|
||||
"""
|
||||
num_requests = seq_lens.size(0)
|
||||
total_world_size = pcp_world_size * dcp_world_size
|
||||
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
|
||||
rank_offsets = (torch.arange(total_world_size,
|
||||
dtype=torch.int32).unsqueeze(0).repeat(
|
||||
num_requests, 1))
|
||||
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
|
||||
total_world_size * cp_kv_cache_interleave_size)
|
||||
remainder = seq_lens_tiled - base * total_world_size
|
||||
remainder = torch.clip(
|
||||
remainder - rank_offsets * cp_kv_cache_interleave_size,
|
||||
0,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens = (base + remainder).reshape(
|
||||
[-1, pcp_world_size, dcp_world_size])
|
||||
return dcp_local_seq_lens
|
||||
|
||||
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
    """Build per-step prefill-context-parallel (pcp) / decode-context-parallel
    (dcp) attention metadata.

    Populates several caches on ``self`` (q/kv index tensors, restore
    indices, mask seqlens) and returns an
    ``AscendPrefillContextParallelMetadata`` (or None when
    pcp_size * dcp_size == 1).
    """
    # In dummy run num_reqs == 0, update it from seq_lens
    num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
    # With spec decode + pcp the per-rank query_lens are split; use the
    # full (pre-split) lengths recorded by the mtp input preparation.
    query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
        if self.pcp_size > 1 and self.speculative_config else self.query_lens
    num_decodes = (query_lens <= self.decode_threshold).sum().item()
    num_prefills = num_reqs - num_decodes
    # Every rank contributes total_num_scheduled_tokens after allgather.
    num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
    self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
    long_seq_metadata = None
    if self.pcp_size * self.dcp_size > 1:
        # Context length seen by each request: decodes use total tokens,
        # prefills use the already-computed prefix.
        decode_context_lens = self.input_batch.num_tokens[:num_decodes]
        prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
            num_decodes:num_reqs]
        context_lens = np.concatenate(
            [decode_context_lens, prefill_context_lens])
        # One [pcp, dcp] split per (request, spec-decode step) slot.
        num_computed_tokens_of_pcp_dcp = torch.zeros(
            [
                num_reqs * self.decode_threshold, self.pcp_size,
                self.dcp_size
            ],
            dtype=torch.int32,
        )
        # For pcp + spec decode, we flatten seq_lens
        # to avoid irregular spec_attn_mask shape.
        # Same as block_table, we flatten decode seq_lens to query_lens,
        # and keep prefill seq_lens unchanged.
        for decode_idx in range(self.decode_threshold):
            num_computed_tokens_of_pcp_dcp[
                self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                self._get_cp_local_seq_lens(
                    torch.tensor(context_lens) - decode_idx,
                    self.pcp_size,
                    self.dcp_size,
                    self.parallel_config.cp_kv_cache_interleave_size,
                )
        if self.decode_threshold > 1:
            # Select only the slots that correspond to real tokens:
            # per-query-token slots for decodes, the final slot per
            # request for prefills.
            num_computed_tokens_of_pcp_dcp_list = []
            if num_decodes:
                num_decodes_flatten = \
                    self.query_lens[:num_decodes].sum().item()
                if self.query_lens[:num_decodes].min().item(
                ) == self.decode_threshold:
                    # All decodes are full-length: take a contiguous prefix.
                    decode_flatten_idx = list(range(num_decodes_flatten))
                else:
                    # Ragged decode lengths: take the last query_len slots
                    # of each request's decode_threshold-sized window.
                    decode_flatten_idx = []
                    for req_id in range(num_decodes):
                        offset = (req_id + 1) * self.decode_threshold
                        decode_flatten_idx += \
                            list(range(offset - self.query_lens[req_id], offset))
                num_computed_tokens_of_pcp_dcp_list.append(
                    num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
            if num_prefills:
                num_computed_tokens_of_pcp_dcp_list.append(
                    num_computed_tokens_of_pcp_dcp[
                        (num_decodes + 1) * self.decode_threshold -
                        1::self.decode_threshold])
            num_computed_tokens_of_pcp_dcp = torch.cat(
                num_computed_tokens_of_pcp_dcp_list, dim=0)
        long_seq_metadata = AscendPrefillContextParallelMetadata(
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
            num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
            numpy())
    if self.pcp_size > 1:
        # Build DualChunkSwap gather indices: each prefill request's local
        # tokens are a head chunk (chunk id = pcp_rank) and a tail chunk
        # (chunk id = 2*pcp_size-1-pcp_rank) of the full sequence.
        q_head_idx, q_tail_idx = [], []
        kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
        kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
        chunk_seqlens = []
        kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
        q_req_offset = 0
        kv_req_offset = 0
        q_head_chunk_id = self.pcp_rank
        q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
        for i, seq_len in enumerate(self.query_lens):
            # Decode requests are not chunk-split by pcp.
            if i < num_decodes:
                continue
            chunk_len = seq_len // 2
            chunk_seqlens.append(chunk_len)
            q_head_idx.extend(
                list(range(q_req_offset, q_req_offset + chunk_len)))
            # KV before the head chunk needs no causal mask; the head
            # chunk itself attends with a mask.
            kv_with_q_head_nomask_idx.extend(
                list(
                    range(kv_req_offset, kv_req_offset +
                          chunk_len * q_head_chunk_id)))
            kv_with_q_head_mask_idx.extend(
                list(
                    range(
                        kv_req_offset + chunk_len * q_head_chunk_id,
                        kv_req_offset + chunk_len *
                        (q_head_chunk_id + 1))))
            kv_with_q_head_nomask_seqlens.append(chunk_len *
                                                 q_head_chunk_id)

            q_tail_idx.extend(
                list(
                    range(q_req_offset + chunk_len,
                          q_req_offset + chunk_len * 2)))
            kv_with_q_tail_nomask_idx.extend(
                list(
                    range(kv_req_offset, kv_req_offset +
                          chunk_len * q_tail_chunk_id)))
            kv_with_q_tail_mask_idx.extend(
                list(
                    range(
                        kv_req_offset + chunk_len * q_tail_chunk_id,
                        kv_req_offset + chunk_len *
                        (q_tail_chunk_id + 1))))
            kv_with_q_tail_nomask_seqlens.append(chunk_len *
                                                 q_tail_chunk_id)

            q_req_offset += seq_len
            # KV holds the full (allgathered) sequence of every rank.
            kv_req_offset += seq_len * self.pcp_size

        # Convert lists to tensors and move to device
        def _list_to_tensor(lst, device, dtype=torch.int32):
            tensor_npu = torch.zeros(len(lst),
                                     dtype=dtype,
                                     device=device)
            tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
                             non_blocking=True)
            return tensor_npu

        q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
        q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
        self.q_head_idx_tensor = q_head_idx_tensor
        self.q_tail_idx_tensor = q_tail_idx_tensor

        # Permutation that restores original token order after the
        # head/tail outputs are concatenated.
        q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
        q_full_idx = q_full_idx.to(torch.float32).argsort().to(
            torch.int32)
        self.q_full_idx = q_full_idx

        self.kv_idx_names = {
            'kv_with_q_head_nomask_idx_tensor':
            kv_with_q_head_nomask_idx,
            'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
            'kv_with_q_tail_nomask_idx_tensor':
            kv_with_q_tail_nomask_idx,
            'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
        }
        for key, value in self.kv_idx_names.items():
            tensor_npu = _list_to_tensor(value, self.device)
            self.kv_idx_names[key] = tensor_npu

        # [2, num_prefills] (q_len, kv_len) pairs for the masked/unmasked
        # attention calls.
        attn_mask_seqlens = torch.tensor(
            [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
        head_attn_nomask_seqlens = torch.tensor(
            [chunk_seqlens, kv_with_q_head_nomask_seqlens],
            dtype=torch.int32)
        tail_attn_nomask_seqlens = torch.tensor(
            [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
            dtype=torch.int32)
        pcp_prefill_mask = self.attn_mask

        self.extra_long_seq_kwargs = {
            'attn_mask_seqlens': attn_mask_seqlens,
            'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
            'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
            'pcp_prefill_mask': pcp_prefill_mask
        }
        long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
                                                                                     num_actual_tokens_pcp_padded]
        long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
        long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
        long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
        long_seq_metadata.q_full_idx = self.q_full_idx
        long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
            'kv_with_q_head_nomask_idx_tensor']
        long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
            'kv_with_q_head_mask_idx_tensor']
        long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
            'kv_with_q_tail_nomask_idx_tensor']
        long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
            'kv_with_q_tail_mask_idx_tensor']
        long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
            'attn_mask_seqlens']
        long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
            'head_attn_nomask_seqlens']
        long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
            'tail_attn_nomask_seqlens']
        long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
            'pcp_prefill_mask']
    self.long_seq_metadata = long_seq_metadata
    return long_seq_metadata
|
||||
|
||||
def _generate_pcp_mtp_input(
    self,
    num_reqs: int,
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: dict[str, int],
    with_prefill: bool = True,
    req_indices=None,
    positions_np=None,
    cu_num_tokens=None,
):
    """
    While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
    but mtp need to shift original input_ids before pcp splitting,
    so we record original input_ids here.

    Side effects: fills the ``*_pcp_full`` CPU/GPU buffers, stores
    ``self.cu_num_tokens_pcp_full`` and, for multi-token-prediction decode
    batches, ``self.mtp_slot_pad``.
    """
    total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
    # Per-request scheduled token counts in batch order (pre-split).
    num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
    for i, req_id in enumerate(self.input_batch.req_ids):
        num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
    self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
        num_scheduled_tokens_pcp_full)
    req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
                                     num_scheduled_tokens_pcp_full)
    cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
    # query_start_loc over the full (unsplit) token stream; unused slots
    # are filled with -1.
    self.query_start_loc_pcp_full.np[0] = 0
    self.query_start_loc_pcp_full.np[1:num_reqs +
                                     1] = cu_num_tokens_pcp_full
    self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
    # Within-request arange (0..len-1 per request) over the full stream.
    cumsums_offsets_pcp_full = np.repeat(
        cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
        num_scheduled_tokens_pcp_full)
    arange_pcp_full = self.arange_np[:
                                     total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
    positions_pcp_full_np = self.positions_pcp_full_np[:
                                                       total_num_scheduled_tokens_pcp_full]
    # Absolute position = computed prefix length + within-request offset.
    np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
           arange_pcp_full,
           out=positions_pcp_full_np)
    # Flattened indices into the [num_reqs, max_model_len] token table.
    token_indices_pcp_full = (
        positions_pcp_full_np +
        req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
    torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(token_indices_pcp_full),
                       out=self.input_ids_pcp_full.
                       cpu[:total_num_scheduled_tokens_pcp_full])
    self.query_lens_pcp_full.copy_to_gpu()
    self.query_start_loc_pcp_full.copy_to_gpu()
    self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
        self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
        non_blocking=True,
    )
    self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
    # For mtpx, pre-allocate mtp slot_mapping here
    if self.decode_threshold > 2 and not with_prefill:
        num_tokens_ori = sum(list(num_scheduled_tokens.values()))
        # Each request gains (decode_threshold - 2) extra speculative slots.
        num_tokens_mtp = \
            num_tokens_ori + num_reqs * (self.decode_threshold - 2)
        num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
        req_indices_split = np.array_split(req_indices,
                                           cu_num_tokens)[:num_reqs]
        positions_split = np.array_split(positions_np,
                                         cu_num_tokens)[:num_reqs]
        # Extend each request with the extra speculative positions
        # (consecutive positions after its last token).
        for req_idx in range(num_reqs):
            ori_req_indice = req_indices_split[req_idx]
            ori_position = positions_split[req_idx]
            req_indices_split[req_idx] = np.append(
                ori_req_indice,
                np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
            positions_split[req_idx] = np.append(
                ori_position,
                np.arange(ori_position[-1] + 1,
                          ori_position[-1] + self.decode_threshold - 1))
        req_indices_mtp = np.concatenate(req_indices_split)
        positions_mtp = np.concatenate(positions_split)
        self.input_batch.block_table.compute_slot_mapping(
            req_indices_mtp, positions_mtp)
        mtp_slot_ori = self.input_batch.block_table.block_tables[
            0].slot_mapping.cpu[:num_tokens_mtp]
        # After allgather, only every pcp_size-th slot is a real token;
        # pad the rest with -1.
        unpad_mask = np.repeat(False, num_tokens_mtp_pad)
        unpad_mask[::self.pcp_size] = True
        mtp_slot_pad = \
            torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
        mtp_slot_pad[unpad_mask] = mtp_slot_ori
        self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
|
||||
|
||||
def _prepare_multimodal_fields(self):
|
||||
"""
|
||||
Ensures specific multimodal tensors are on CPU.
|
||||
|
||||
686
vllm_ascend/worker/pcp_utils.py
Normal file
686
vllm_ascend/worker/pcp_utils.py
Normal file
@@ -0,0 +1,686 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class PCPManager:
|
||||
"""
|
||||
Manager for Prefill Context Parallelism (PCP) metadata and buffers.
|
||||
|
||||
This manager encapsulates all PCP-related buffers and logic so that the
|
||||
ModelRunner can access them via `self.pcp_manager`.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    pcp_world_size: int,
    pcp_rank: int,
    dcp_world_size: int,
    dcp_rank: int,
    max_buffer_num_tokens: int,
    max_num_reqs: int,
    device: torch.device,
    vllm_config: VllmConfig,
    pin_memory: bool = False,
) -> None:
    """Allocate all PCP/DCP bookkeeping buffers up front.

    Args:
        pcp_world_size / pcp_rank: prefill-context-parallel group size/rank.
        dcp_world_size / dcp_rank: decode-context-parallel group size/rank.
        max_buffer_num_tokens: capacity of the token-indexed buffers
            (must cover the padded allgather size).
        max_num_reqs: capacity of the request-indexed buffers.
        device: accelerator device for the GPU sides of the buffers.
        vllm_config: global config; scheduler and speculative sub-configs
            are read here.
        pin_memory: whether CPU buffers use pinned memory.
    """
    self.pcp_world_size = pcp_world_size
    self.pcp_world_rank = pcp_rank
    self.dcp_world_size = dcp_world_size
    self.dcp_world_rank = dcp_rank
    self.speculative_config = vllm_config.speculative_config
    # Max query length of a "decode" request: 1 + num speculative tokens.
    self.decode_threshold = 1 + (
        self.speculative_config.num_speculative_tokens
        if self.speculative_config else 0)
    self.vllm_config = vllm_config
    self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
    self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
    self.device = device
    # Permutation restoring original token order after pcp allgather.
    self.pcp_allgather_restore_idx = CpuGpuBuffer(
        max_buffer_num_tokens,
        dtype=torch.int64,
        device=device,
        pin_memory=pin_memory,
    )
    # Slot mapping aligned to the padded allgather layout; -1 marks pads.
    self.pcp_padded_slot_mapping = torch.full(
        (max_buffer_num_tokens, ),
        fill_value=-1,
        dtype=torch.int32,
        device=device,
    )
    # Number of pad tokens added per request this step.
    self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ),
                                               device="cpu",
                                               dtype=torch.int64)
    self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy()
    # Boolean mask of real (non-pad) tokens in the padded allgather buffer.
    self.pcp_unpad_mask_cpu_tensor = torch.zeros(
        (max_buffer_num_tokens, ),
        device="cpu",
        dtype=torch.bool,
    )
    self.num_actual_tokens_pcp_padded = 0
    self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
    # Per-rank KV recovery indices (filled by generate_kv_idx).
    self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
        [] for _ in range(self.pcp_world_size)
    ]
    # Pre-built 0..N index pool large enough for the worst-case padded
    # token count; sliced by generate_kv_idx.
    self.full_indices = list(
        range(self.max_num_tokens * self.pcp_world_size *
              self.dcp_world_size + self.pcp_world_size *
              self.dcp_world_size * self.max_num_reqs))
    # Buffers holding the full (pre-split) inputs, needed only when
    # speculative decoding runs together with pcp/dcp.
    if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
        self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens,
                                               dtype=torch.int32,
                                               device=device,
                                               pin_memory=pin_memory)
        self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1,
                                                     dtype=torch.int32,
                                                     device=device,
                                                     pin_memory=pin_memory)
        self.positions_pcp_full = torch.zeros(self.max_num_tokens,
                                              dtype=torch.int64,
                                              device="cpu",
                                              pin_memory=pin_memory)
        self.positions_pcp_full_np = self.positions_pcp_full.numpy()
        self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs,
                                                dtype=torch.int32,
                                                device=device,
                                                pin_memory=pin_memory)
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
arange_np: np.ndarray,
|
||||
cumsum_dtype: np.dtype | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the cumulative sum and batched arange of the given array.
|
||||
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
||||
# Equivalent to but faster than:
|
||||
# np.concatenate([np.arange(n) for n in num_scheduled_tokens])
|
||||
"""
|
||||
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
||||
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype)
|
||||
total_num_tokens = cu_num_tokens[-1]
|
||||
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
||||
arange = arange_np[:total_num_tokens] - cumsums_offsets
|
||||
|
||||
return cu_num_tokens, arange
|
||||
|
||||
def update_tokens_for_pcp(
    self,
    num_scheduled_tokens: np.ndarray,
    arange_np: np.ndarray,
    num_reqs: int,
    reorder_batch_threshold: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Update token counts and positions for Prefill Context Parallelism (PCP).

    When using Prefill Context Parallelism, each request's prefill sequence is
    split across multiple PCP ranks. The splitting strategy used here is the
    "DualChunkSwap" style: each request's (padded) sequence is split into
    2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved
    head/tail pattern to balance load.

    This function:
    - Computes how many tokens each request should be processed by the current
      PCP rank (pcp_tokens).
    - Computes the flattened positions of those tokens within the local
      padded buffer (pcp_positions).
    - Updates runner state arrays used to restore original order and mask out
      padded tokens after allgather:
      - self.num_pcp_pads_cpu: number of pads added per request
      - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the
        padded allgather buffer
      - self.pcp_allgather_restore_idx: index array used to restore original
        ordering after per-rank allgather and interleaving.

    Args:
        num_scheduled_tokens: 1D numpy array of length num_reqs containing
            the number of new tokens scheduled per request.
        arange_np: 1D numpy array of length max_buffer_num_tokens used for
            efficient batched arange operations.
        num_reqs: Total number of requests in the batch.
        reorder_batch_threshold: Threshold for decode vs prefill requests.

    Returns:
        Tuple (pcp_tokens, pcp_positions):
        - pcp_tokens: number of tokens per request that this PCP rank will
          actually process (after splitting / replication).
        - pcp_positions: flattened positions for those tokens on this rank,
          used to build the positions buffer for the model.

    Example:
        >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After update_tokens_for_pcp.
        >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7])
        >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5])
        >>> Meanwhile, the following results are same for each pcp rank
        >>> self.num_pcp_pads_cpu
        [1, 3, 0]
        >>> self.pcp_unpad_mask_cpu
        [True, False, True, True, True, True, True, False, False,
        False, True, True, True, True, True, True, True, True]
        >>> self.pcp_allgather_restore_idx
        [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
    """

    assert reorder_batch_threshold is not None, (
        "PCP depends on reorder batch to split decode and prefill requests."
    )
    # Requests at or below the threshold are decodes; the reordered batch
    # places them first.
    num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
    num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])

    # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
    # We first pad each request's token count up to that multiple.
    num_padded_scheduled_tokens = np.ceil(
        num_scheduled_tokens / (2 * self.pcp_world_size)).astype(
            np.int32) * (2 * self.pcp_world_size)

    # PCP does not split decode requests. For decode requests, we instead
    # duplicate the scheduled tokens across the pcp_world_size ranks.
    num_padded_scheduled_tokens[:num_decode_reqs] = (
        num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size)

    # Record how many pads were added per request (padded - original).
    self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens -
                                        num_scheduled_tokens)

    # cu_padded_tokens: cumulative sum of padded token counts,
    # pcp_padded_arange: per-request arange flattened for padded tokens.
    cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
        num_padded_scheduled_tokens, arange_np)
    # Build the mask that marks which positions in the padded allgather buffer
    # correspond to real (unpadded) tokens.
    self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
        pcp_padded_arange < np.repeat(num_scheduled_tokens,
                                      num_padded_scheduled_tokens))
    # Decode tokens are duplicated across ranks: keep only rank 0's copy.
    # (This reshapes a view, so the writes hit self.pcp_unpad_mask_cpu.)
    unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens *
                                                self.pcp_world_size]
    unpad_mask_decode = unpad_mask_decode.reshape(
        [-1, self.pcp_world_size])
    unpad_mask_decode[:, 0] = True
    unpad_mask_decode[:, 1:] = False
    pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size

    # Compute per-request "chunk sizes" for the head/tail splitting.
    # For prefill requests, we further split the pcp_tokens into two chunks
    # (head and tail). For decode requests, the chunk equals pcp_tokens.
    pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
    pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]

    # Build arange-style helpers for pcp tokens and chunk sizes:
    # - pcp_arange gives indices repeated for each token in pcp_tokens
    # - pcp_chunk_arange gives indices repeated for each position inside chunks
    _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np)
    _, pcp_chunk_arange = self._get_cumsum_and_arange(
        pcp_chunk_sizes, arange_np)

    # Mask that marks whether a position belongs to the head chunk (True)
    # or the tail chunk (False). For decode requests, tail chunk won't exist
    # and is handled specially below.
    pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
                                                 pcp_tokens)

    def get_current_rank_positions(positions_start_loc: int | np.ndarray,
                                   rank: int):
        """
        Compute flattened positions for the given rank with a given start
        offset for each request (positions_start_loc).

        - For head chunks: start at positions_start_loc + rank * chunk_size.
        - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank -
          1) * chunk_size.
        - For decode requests: no tail chunks; their positions are filled from the
          contiguous (unpadded) `tokens` arange instead (handled after).
        """
        positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
        head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
        tail_start_loc = (
            positions_start_loc +
            (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes)
        # Fill head positions using chunk arange offset by head_start_loc.
        positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(
            head_start_loc, pcp_chunk_sizes)
        # Fill tail positions. Note decode requests do not have tail chunks,
        # so the tail filling is only for prefill positions.
        positions[~pcp_head_chunk_mask] = (
            pcp_chunk_arange[num_decode_tokens:] +
            np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:])
        return positions

    positions = get_current_rank_positions(0, self.pcp_world_rank)
    # Decode tokens are duplicated only after AG. But their positions are
    # same without prefill context parallel.
    if num_decode_reqs > 0:
        positions[:num_decode_tokens] = self._get_cumsum_and_arange(
            num_scheduled_tokens[:num_decode_reqs], arange_np)[1]

    # Build the restore index used after allgather.
    padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
    padded_pos_start_loc[0] = 0
    all_positions_lst = [
        get_current_rank_positions(padded_pos_start_loc, rank_i)
        for rank_i in range(self.pcp_world_size)
    ]
    all_positions = np.concatenate(all_positions_lst)
    # argsort of the global positions maps allgathered order -> original.
    self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = (
        all_positions.argsort())
    self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])

    return (
        pcp_tokens[:num_reqs],
        positions,
    )
|
||||
|
||||
def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
    """Index of each request's last real token in the pcp-allgathered buffer.

    The allgathered stream holds pcp_world_size times the local tokens per
    request, minus that request's pad tokens; the last real token sits one
    before that boundary.
    """
    full_cu = torch.from_numpy(cu_num_tokens) * self.pcp_world_size
    pads = self.num_pcp_pads_cpu_tensor[:num_reqs]
    # Step back over the request's pads, then one more to land on the
    # final real token.
    return full_cu - pads - 1
|
||||
|
||||
def get_discard_request_mask(
    self,
    num_computed_tokens_cpu: np.ndarray,
    num_scheduled_tokens: np.ndarray,
    num_reqs: int,
    num_tokens_np: np.ndarray,
):
    """Boolean mask of requests whose real token progress after this step
    is still below ``num_tokens_np``.

    Scheduled counts are per-rank, so they are scaled by pcp_world_size and
    the pcp pad tokens are subtracted to get true progress.
    """
    full_scheduled = num_scheduled_tokens * self.pcp_world_size
    real_progress = (num_computed_tokens_cpu[:num_reqs] + full_scheduled -
                     self.num_pcp_pads_cpu[:num_reqs])
    return real_progress < num_tokens_np
|
||||
|
||||
def get_padded_slot_mapping(self, num_tokens: int,
                            slot_mapping: torch.Tensor):
    """Scatter *slot_mapping* into the padded slot-mapping buffer.

    After the pcp allgather-and-restore, the KV stream contains pad tokens;
    the slot mapping must be expanded to the same padded layout so real
    tokens land in the right cache slots (pads keep the buffer's -1 fill).
    """
    padded_len = num_tokens * self.pcp_world_size
    padded_view = self.pcp_padded_slot_mapping[:padded_len]
    real_token_mask = self.pcp_unpad_mask_cpu_tensor[:padded_len]
    # Boolean scatter: one slot per real token, in order.
    padded_view[real_token_mask] = slot_mapping
    return padded_view
|
||||
|
||||
def get_restore_hidden_states(
    self,
    hidden_states: torch.Tensor,
):
    """All-gather this rank's hidden states across the pcp group and
    reorder them back to the original token order.
    """
    from vllm.distributed.parallel_state import get_pcp_group
    # NOTE: we must slice hidden_states before gathering because
    # pcp_allgather_restore_idx ignores the padding added by CUDA Graph.
    local_len = self.num_actual_tokens_pcp_padded // self.pcp_world_size
    gathered = get_pcp_group().all_gather(hidden_states[:local_len], 0)
    restore_idx = self.pcp_allgather_restore_idx.gpu[:gathered.shape[0]]
    return gathered.index_select(0, restore_idx)
|
||||
|
||||
def generate_pcp_mtp_input(
    self,
    num_reqs: int,
    total_num_scheduled_tokens: int,
    num_scheduled_tokens: dict[str, int],
    with_prefill: bool = True,
    input_batch=None,
    arange_np=None,
    req_indices=None,
    positions_np=None,
    cu_num_tokens=None,
):
    """
    While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
    but mtp need to shift original input_ids before pcp splitting,
    so we record original input_ids here.

    Side effects: fills the ``*_pcp_full`` CPU/GPU buffers, stores
    ``self.cu_num_tokens_pcp_full`` and, for multi-token-prediction decode
    batches, ``self.mtp_slot_pad``.
    """
    total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
    # Per-request scheduled token counts in batch order (pre-split).
    num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
    for i, req_id in enumerate(input_batch.req_ids):
        num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
    self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
        num_scheduled_tokens_pcp_full)
    req_indices_pcp_full = np.repeat(arange_np[:num_reqs],
                                     num_scheduled_tokens_pcp_full)
    cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
    # query_start_loc over the full (unsplit) token stream; unused slots
    # are filled with -1.
    self.query_start_loc_pcp_full.np[0] = 0
    self.query_start_loc_pcp_full.np[1:num_reqs +
                                     1] = cu_num_tokens_pcp_full
    self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
    # Within-request arange (0..len-1 per request) over the full stream.
    cumsums_offsets_pcp_full = np.repeat(
        cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
        num_scheduled_tokens_pcp_full)
    arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
    positions_pcp_full_np = self.positions_pcp_full_np[:
                                                       total_num_scheduled_tokens_pcp_full]
    # Absolute position = computed prefix length + within-request offset.
    np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
           arange_pcp_full,
           out=positions_pcp_full_np)
    # Flattened indices into the [num_reqs, max_model_len] token table.
    token_indices_pcp_full = (
        positions_pcp_full_np +
        req_indices_pcp_full * input_batch.token_ids_cpu.shape[1])
    torch.index_select(input_batch.token_ids_cpu_tensor.flatten(),
                       0,
                       torch.from_numpy(token_indices_pcp_full),
                       out=self.input_ids_pcp_full.
                       cpu[:total_num_scheduled_tokens_pcp_full])
    self.query_lens_pcp_full.copy_to_gpu()
    self.query_start_loc_pcp_full.copy_to_gpu()
    self.input_ids_pcp_full.copy_to_gpu(
        total_num_scheduled_tokens_pcp_full)
    self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
    # For mtpx, pre-allocate mtp slot_mapping here
    if self.decode_threshold > 2 and not with_prefill:
        num_tokens_ori = sum(list(num_scheduled_tokens.values()))
        # Each request gains (decode_threshold - 2) extra speculative slots.
        num_tokens_mtp = \
            num_tokens_ori + num_reqs * (self.decode_threshold - 2)
        num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
        req_indices_split = np.array_split(req_indices,
                                           cu_num_tokens)[:num_reqs]
        positions_split = np.array_split(positions_np,
                                         cu_num_tokens)[:num_reqs]
        # Extend each request with the extra speculative positions
        # (consecutive positions after its last token).
        for req_idx in range(num_reqs):
            ori_req_indice = req_indices_split[req_idx]
            ori_position = positions_split[req_idx]
            req_indices_split[req_idx] = np.append(
                ori_req_indice,
                np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
            positions_split[req_idx] = np.append(
                ori_position,
                np.arange(ori_position[-1] + 1,
                          ori_position[-1] + self.decode_threshold - 1))
        req_indices_mtp = np.concatenate(req_indices_split)
        positions_mtp = np.concatenate(positions_split)
        input_batch.block_table.compute_slot_mapping(
            req_indices_mtp, positions_mtp)
        mtp_slot_ori = input_batch.block_table.block_tables[
            0].slot_mapping.cpu[:num_tokens_mtp]
        # After allgather, only every pcp_world_size-th slot is a real
        # token; pad the rest with -1.
        unpad_mask = np.repeat(False, num_tokens_mtp_pad)
        unpad_mask[::self.pcp_world_size] = True
        mtp_slot_pad = \
            torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
        mtp_slot_pad[unpad_mask] = mtp_slot_ori
        self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
|
||||
|
||||
def _get_cp_local_seq_lens(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
pcp_world_size: int = 1,
|
||||
dcp_world_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
|
||||
use this function to calculate split decode seq_lens of each (p/d)cp rank.
|
||||
"""
|
||||
num_requests = seq_lens.size(0)
|
||||
total_world_size = pcp_world_size * dcp_world_size
|
||||
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
|
||||
rank_offsets = (torch.arange(total_world_size,
|
||||
dtype=torch.int32).unsqueeze(0).repeat(
|
||||
num_requests, 1))
|
||||
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
|
||||
total_world_size * cp_kv_cache_interleave_size)
|
||||
remainder = seq_lens_tiled - base * total_world_size
|
||||
remainder = torch.clip(
|
||||
remainder - rank_offsets * cp_kv_cache_interleave_size,
|
||||
0,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens = (base + remainder).reshape(
|
||||
[-1, pcp_world_size, dcp_world_size])
|
||||
return dcp_local_seq_lens
|
||||
|
||||
def generate_kv_idx(self, scheduler_output, input_batch):
    """Build the index used to restore chunked (head/tail) prefill KV order.

    With pcp > 1 each prefill request is padded to a multiple of
    2 * pcp_world_size and split into 2 * pcp_world_size chunks; rank r
    owns chunk r (head) and chunk 2 * pcp_world_size - 1 - r (tail).
    This method gathers, per rank, the global token indices those two
    chunks cover, then stores in ``self.cp_kv_recover_idx_for_chunk`` the
    argsort permutation that maps rank-concatenated order back to the
    original token order.

    No-op when pcp is disabled (pcp_world_size <= 1).
    """
    if self.pcp_world_size <= 1:
        return
    # Per-rank lists of global token indices owned by that rank.
    self.cp_kv_recover_idx_for_chunk = [[]
                                        for _ in range(self.pcp_world_size)
                                        ]

    for i, req_id in enumerate(input_batch.req_ids):
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
            req_id]
        is_prefill = input_batch.num_computed_tokens_cpu[
            i] < input_batch.num_prompt_tokens[i]
        if not is_prefill:
            # Decode requests are not chunk-split; nothing to record.
            continue
        # Pad to a multiple of 2 * pcp_world_size so the request splits
        # into equal head/tail chunks.
        num_cp_padded_scheduled_tokens = cdiv(
            num_scheduled_tokens,
            2 * self.pcp_world_size) * (2 * self.pcp_world_size)
        chunk_size = num_cp_padded_scheduled_tokens // (
            2 * self.pcp_world_size)
        # Global offset of this request's tokens: every rank has
        # accumulated the same number of indices so far.
        num_added_recover_tokens = len(
            self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size
        for rank in range(self.pcp_world_size):
            # Head chunk: chunk index == rank.
            head_start = rank * chunk_size + num_added_recover_tokens
            self.cp_kv_recover_idx_for_chunk[rank].extend(
                self.full_indices[head_start:head_start + chunk_size])
            # Tail chunk: mirrored chunk from the end of the request.
            tail_start = (num_cp_padded_scheduled_tokens -
                          (rank + 1) * chunk_size +
                          num_added_recover_tokens)
            self.cp_kv_recover_idx_for_chunk[rank].extend(
                self.full_indices[tail_start:tail_start + chunk_size])

    # Single host->device transfer of the rank-concatenated indices.
    # (Previously the same data was transferred twice: once via
    # from_numpy(...).to(device) and again via copy_() of an
    # array->list->tensor round trip of identical values.)
    flat_idx = torch.from_numpy(
        np.concatenate(self.cp_kv_recover_idx_for_chunk)).to(
            device=self.device, non_blocking=True)
    # argsort of the owned-token order yields the inverse permutation;
    # float32 cast kept for device argsort support.
    self.cp_kv_recover_idx_for_chunk = flat_idx.to(
        torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
                          attn_mask, input_batch):
    """Build AscendPrefillContextParallelMetadata for pcp/dcp attention.

    Computes per-rank computed-token counts and (for pcp > 1) the
    head/tail chunk index tensors used by chunked prefill attention,
    caches them on ``self``, and returns the assembled metadata object
    (``None`` when pcp_world_size * dcp_world_size == 1).

    Args:
        total_num_scheduled_tokens: scheduled tokens on this rank; the
            padded global count is this times pcp_world_size.
        query_lens: per-request query lengths (decode requests first).
        attn_mask: prefill attention mask, passed through as
            ``pcp_prefill_mask``.
        input_batch: batch holding num_reqs, num_tokens,
            num_computed_tokens_cpu, etc.
    """
    from vllm_ascend.attention.utils import \
        AscendPrefillContextParallelMetadata
    num_reqs = input_batch.num_reqs or query_lens.size(0)
    # With pcp + spec decode, local query_lens are sharded; use the
    # full (un-sharded) lengths to classify decode vs prefill.
    query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \
        if self.pcp_world_size > 1 and self.speculative_config else query_lens
    # Requests are ordered decodes-first; short queries are decodes.
    num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
    num_prefills = num_reqs - num_decodes
    num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
    self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
    long_seq_metadata = None
    if self.pcp_world_size * self.dcp_world_size > 1:
        # Context lengths: full num_tokens for decodes, already-computed
        # tokens for prefills.
        decode_context_lens = input_batch.num_tokens[:num_decodes]
        prefill_context_lens = input_batch.num_computed_tokens_cpu[
            num_decodes:num_reqs]
        context_lens = np.concatenate(
            [decode_context_lens, prefill_context_lens])
        # One row per (request, decode position); filled below and
        # possibly flattened/compacted afterwards.
        num_computed_tokens_of_pcp_dcp = torch.zeros(
            [
                num_reqs * self.decode_threshold, self.pcp_world_size,
                self.dcp_world_size
            ],
            dtype=torch.int32,
        )
        # For pcp + spec decode, we flatten seq_lens
        # to avoid irregular spec_attn_mask shape.
        # Same as block_table, we flatten decode seq_lens to query_lens,
        # and keep prefill seq_lens unchanged.
        for decode_idx in range(self.decode_threshold):
            num_computed_tokens_of_pcp_dcp[
                self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                self._get_cp_local_seq_lens(
                    torch.tensor(context_lens) - decode_idx,
                    self.pcp_world_size,
                    self.dcp_world_size,
                    self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                )
        if self.decode_threshold > 1:
            # Compact the oversized buffer: keep query_lens rows per
            # decode request and the last row per prefill request.
            num_computed_tokens_of_pcp_dcp_list = []
            if num_decodes:
                num_decodes_flatten = \
                    query_lens[:num_decodes].sum().item()
                if query_lens[:num_decodes].min().item(
                ) == self.decode_threshold:
                    # All decodes use the full threshold: take a
                    # contiguous prefix.
                    decode_flatten_idx = list(range(num_decodes_flatten))
                else:
                    # Variable decode lengths: take the trailing
                    # query_lens[req] rows of each request's slot.
                    decode_flatten_idx = []
                    for req_id in range(num_decodes):
                        offset = (req_id + 1) * self.decode_threshold
                        decode_flatten_idx += \
                            list(range(offset - query_lens[req_id], offset))
                num_computed_tokens_of_pcp_dcp_list.append(
                    num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
            if num_prefills:
                num_computed_tokens_of_pcp_dcp_list.append(
                    num_computed_tokens_of_pcp_dcp[
                        (num_decodes + 1) * self.decode_threshold -
                        1::self.decode_threshold])
            num_computed_tokens_of_pcp_dcp = torch.cat(
                num_computed_tokens_of_pcp_dcp_list, dim=0)
        long_seq_metadata = AscendPrefillContextParallelMetadata(
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
            num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
            numpy())
    if self.pcp_world_size > 1:
        # Each prefill request is split into 2 * pcp_world_size chunks;
        # this rank owns head chunk ``pcp_world_rank`` and the mirrored
        # tail chunk.  Build q/kv index lists for masked and no-mask
        # attention segments.
        q_head_idx, q_tail_idx = [], []
        kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
        kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
        chunk_seqlens = []
        kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
        q_req_offset = 0
        kv_req_offset = 0
        q_head_chunk_id = self.pcp_world_rank
        q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
        for i, seq_len in enumerate(query_lens):
            if i < num_decodes:
                # Decodes are not chunk-split.
                continue
            # NOTE(review): assumes the local per-request query length
            # is even (head/tail halves) — padding upstream presumably
            # guarantees this; confirm.
            chunk_len = seq_len // 2
            chunk_seqlens.append(chunk_len)
            # Head half of the query for this request.
            q_head_idx.extend(
                list(range(q_req_offset, q_req_offset + chunk_len)))
            # KV fully visible to the head chunk (no causal mask needed).
            kv_with_q_head_nomask_idx.extend(
                list(
                    range(kv_req_offset, kv_req_offset +
                          chunk_len * q_head_chunk_id)))
            # KV aligned with the head chunk itself (causal mask applies).
            kv_with_q_head_mask_idx.extend(
                list(
                    range(
                        kv_req_offset + chunk_len * q_head_chunk_id,
                        kv_req_offset + chunk_len *
                        (q_head_chunk_id + 1))))
            kv_with_q_head_nomask_seqlens.append(chunk_len *
                                                 q_head_chunk_id)

            # Tail half of the query; same construction with the
            # mirrored chunk id.
            q_tail_idx.extend(
                list(
                    range(q_req_offset + chunk_len,
                          q_req_offset + chunk_len * 2)))
            kv_with_q_tail_nomask_idx.extend(
                list(
                    range(kv_req_offset, kv_req_offset +
                          chunk_len * q_tail_chunk_id)))
            kv_with_q_tail_mask_idx.extend(
                list(
                    range(
                        kv_req_offset + chunk_len * q_tail_chunk_id,
                        kv_req_offset + chunk_len *
                        (q_tail_chunk_id + 1))))
            kv_with_q_tail_nomask_seqlens.append(chunk_len *
                                                 q_tail_chunk_id)

            q_req_offset += seq_len
            # KV side holds the all-gathered tokens of every pcp rank.
            kv_req_offset += seq_len * self.pcp_world_size

        # Convert lists to tensors and move to device
        def _list_to_tensor(lst, device, dtype=torch.int32):
            # Allocate on device, then async-copy the host data in.
            tensor_npu = torch.zeros(len(lst),
                                     dtype=dtype,
                                     device=device)
            tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
                             non_blocking=True)
            return tensor_npu

        q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
        q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
        self.q_head_idx_tensor = q_head_idx_tensor
        self.q_tail_idx_tensor = q_tail_idx_tensor

        # Inverse permutation restoring [head; tail] concatenation back
        # to original token order (float32 cast for device argsort).
        q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
        q_full_idx = q_full_idx.to(torch.float32).argsort().to(
            torch.int32)
        self.q_full_idx = q_full_idx

        self.kv_idx_names = {
            'kv_with_q_head_nomask_idx_tensor':
            kv_with_q_head_nomask_idx,
            'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
            'kv_with_q_tail_nomask_idx_tensor':
            kv_with_q_tail_nomask_idx,
            'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
        }
        # Replace each host list with its device tensor in place.
        for key, value in self.kv_idx_names.items():
            tensor_npu = _list_to_tensor(value, self.device)
            self.kv_idx_names[key] = tensor_npu

        # Paired [q_len, kv_len] rows consumed by the attention kernels.
        attn_mask_seqlens = torch.tensor(
            [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
        head_attn_nomask_seqlens = torch.tensor(
            [chunk_seqlens, kv_with_q_head_nomask_seqlens],
            dtype=torch.int32)
        tail_attn_nomask_seqlens = torch.tensor(
            [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
            dtype=torch.int32)
        pcp_prefill_mask = attn_mask

        self.extra_long_seq_kwargs = {
            'attn_mask_seqlens': attn_mask_seqlens,
            'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
            'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
            'pcp_prefill_mask': pcp_prefill_mask
        }
        # pcp > 1 implies pcp * dcp > 1, so long_seq_metadata was
        # created above; attach all cached index tensors to it.
        long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
                                                                                         num_actual_tokens_pcp_padded]
        long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
        long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
        long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
        long_seq_metadata.q_full_idx = self.q_full_idx
        long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
            'kv_with_q_head_nomask_idx_tensor']
        long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
            'kv_with_q_head_mask_idx_tensor']
        long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
            'kv_with_q_tail_nomask_idx_tensor']
        long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
            'kv_with_q_tail_mask_idx_tensor']
        long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
            'attn_mask_seqlens']
        long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
            'head_attn_nomask_seqlens']
        long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
            'tail_attn_nomask_seqlens']
        long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
            'pcp_prefill_mask']
    self.long_seq_metadata = long_seq_metadata
    return long_seq_metadata
|
||||
Reference in New Issue
Block a user