### What this PR does / why we need it? This PR fixes an accuracy issue that occurs when using Prefill/Decode Context Parallelism (PCP/DCP) in conjunction with speculative decoding (MTP). The issue is caused by an irregular attention mask shape when both features are enabled. The fix involves flattening the `block_table` for speculative decoding requests under PCP/DCP to ensure a regular attention mask. This PR also introduces a `use_cp` property for cleaner code and updates dummy runs to handle this scenario correctly. ### Does this PR introduce _any_ user-facing change? No. This is a bug fix that improves accuracy and should not have user-facing API changes. ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: Wang Kunpeng <1289706727@qq.com>
519 lines
20 KiB
Python
519 lines
20 KiB
Python
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm_ascend.worker.pcp_utils import PCPManager
|
|
|
|
|
|
@pytest.mark.parametrize(
    "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
    [
        (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
        (1, 2, 3, [20, 30, 40], 1, False, 50, True),
        (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
        (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
        (2, 1, 3, [5, 10, 15], 3, False, 50, True),
        (2, 1, 3, [40, 50, 60], 0, False, 150, True),
    ])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                                     num_decodes, use_mla, total_tokens,
                                     expect_not_none):
    """generate_pcp_metadata returns None when neither PCP nor DCP is active
    (pcp_size == dcp_size == 1); otherwise it returns a metadata object that
    carries the expected per-feature attributes."""
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.model_config.use_mla = use_mla
    vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
    vllm_config.speculative_config.num_speculative_tokens = 0

    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=0,
                             dcp_world_size=dcp_size,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs

    # The first `num_decodes` requests are decodes (prompt already computed);
    # the remaining requests are fresh prefills (nothing computed yet).
    is_decode = [idx < num_decodes for idx in range(num_reqs)]
    num_computed_tokens = [
        qlen if dec else 0 for qlen, dec in zip(query_lens, is_decode)
    ]
    num_prompt_tokens = [
        qlen // 2 if dec else qlen for qlen, dec in zip(query_lens, is_decode)
    ]
    num_tokens = list(query_lens)

    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
    input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
    input_batch.num_tokens = torch.tensor(num_tokens)
    num_scheduled_tokens = np.array(
        query_lens) - input_batch.num_computed_tokens_cpu

    query_lens = torch.tensor(query_lens)
    result, _ = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
                                                  input_batch,
                                                  num_scheduled_tokens,
                                                  torch.tensor([]),
                                                  num_reqs_padded=num_reqs,
                                                  num_reqs=num_reqs)

    if not expect_not_none:
        assert result is None, f"Expected to return None, but got {type(result)}"
        return

    assert result is not None, "Expected to return a metadata object, but got None."

    for attr in ('num_actual_tokens_pcp_padded',
                 'num_computed_tokens_of_pcp_dcp'):
        assert hasattr(result, attr)

    if pcp_size > 1:
        assert hasattr(result, 'pcp_allgather_restore_idx')

    # Prefill-specific attributes must exist iff at least one prefill
    # request was scheduled.
    if num_reqs - num_decodes > 0:
        for attr in ('q_head_idx_tensor', 'q_tail_idx_tensor', 'q_full_idx',
                     'kv_with_q_head_nomask_idx_tensor',
                     'kv_with_q_head_mask_idx_tensor',
                     'kv_with_q_tail_nomask_idx_tensor',
                     'kv_with_q_tail_mask_idx_tensor', 'attn_mask_seqlens',
                     'head_attn_nomask_seqlens', 'tail_attn_nomask_seqlens'):
            assert hasattr(result, attr)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
    [
        # Case 1: prefill only
        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
        # Case 2: mix prefill and decode
        ([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]),
        # Case 3: request which need to be padded
        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
        # Case 4: single request
        ([10], 1, [0], [10], 4, 0, [4]),
    ])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                     num_prompt_tokens, pcp_size, pcp_rank,
                                     expected_pcp_tokens):
    """update_tokens_for_pcp shards each request's scheduled tokens across
    the PCP group; the returned positions must be a flat vector with one
    entry per local token.

    Fix: the parametrized ``pcp_rank`` was previously ignored — the manager
    was always constructed with ``pcp_rank=0``. It is now forwarded, so
    non-zero-rank cases can be added to the parametrization and actually
    take effect.
    """
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 0
    vllm_config.scheduler_config.max_num_seqs = 1000

    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=pcp_rank,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs
    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens,
                                                   dtype=np.int32)
    input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
    arange_np = np.arange(10000)
    num_scheduled_tokens = np.array(tokens)
    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
    pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
        num_scheduled_tokens, arange_np)

    assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
        f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"

    # One position per local PCP token.
    total_pcp_tokens: int = np.sum(pcp_tokens_result)
    assert positions_result.shape == (total_pcp_tokens,), \
        f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize(
    "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
    [
        # without pcp and dcp
        (torch.tensor([1, 2, 128, 129]), 1, 1, 1,
         torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
        # pcp
        (torch.tensor([1, 2, 128, 129]), 2, 1, 1,
         torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
        # dcp
        (torch.tensor([1, 2, 128, 129]), 1, 2, 1,
         torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
        # pcp + dcp
        (torch.tensor([1, 2, 128, 129]), 2, 2, 1,
         torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
                       [[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
        # specify interleave_size
        (torch.tensor([1, 2, 128, 129]), 2, 1, 2,
         torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
        (torch.tensor([1, 2, 128, 129]), 2, 1, 128,
         torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
        (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
         torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
                       [[128, 0], [0, 0]], [[128, 1], [0, 0]],
                       [[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
    ]
)
# yapf: enable
def test_get_cp_local_seq_lens(
    seq_lens,
    pcp_world_size,
    dcp_world_size,
    cp_kv_cache_interleave_size,
    target,
):
    """_get_cp_local_seq_lens splits each global sequence length into the
    per-(pcp_rank, dcp_rank) local KV lengths, honouring the interleave
    size."""
    mock_config = MagicMock()
    mock_config.model_config = MagicMock()
    mock_config.speculative_config.num_speculative_tokens = 0
    manager = PCPManager(pcp_world_size=pcp_world_size,
                         pcp_rank=0,
                         dcp_world_size=dcp_world_size,
                         dcp_rank=0,
                         max_buffer_num_tokens=10000,
                         max_num_reqs=1000,
                         device="cpu",
                         vllm_config=mock_config,
                         use_async_scheduling=False,
                         pin_memory=False)
    local_seq_lens = manager._get_cp_local_seq_lens(
        seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size)
    assert torch.equal(local_seq_lens, target)
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize(
    "req_ids, num_computed_tokens," \
    "token_ids_tensor_list," \
    "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
    "target_input_ids_pcp_full, target_query_start_loc_pcp_full",
    [
        # prefill
        (
            ['0'], np.array([0]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344])],
            1, 6, {'0': 6},
            torch.tensor([0, 671, 6102, 294, 8760, 344]),
            torch.tensor([0, 6])
        ),
        # decode
        (
            ['0'], np.array([6]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
            1, 2, {'0': 2},
            torch.tensor([88907, 0]),
            torch.tensor([0, 2])
        ),
        # decode + prefill
        (
            ['0', '1'], np.array([6, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            ],
            2, 12, {'0': 2, '1': 10},
            torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            torch.tensor([0, 2, 12])
        ),
        # decodes + prefills
        (
            ['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
                torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
                torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
            ],
            4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
            torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
                          0, 671, 5335, 1469, 7539, 305, 6397]),
            torch.tensor([0, 2, 4, 12, 19])
        ),
    ])
# yapf: enable
def test_generate_pcp_mtp_input(
    req_ids,
    num_computed_tokens,
    token_ids_tensor_list,
    num_reqs,
    total_num_scheduled_tokens,
    num_scheduled_tokens,
    target_input_ids_pcp_full,
    target_query_start_loc_pcp_full,
):
    """generate_pcp_mtp_input must gather, per request, the scheduled slice
    of the token-id table into ``input_ids_pcp_full`` and build the matching
    cumulative ``query_start_loc_pcp_full``.

    Fix: removed a redundant self-reassignment
    (``token_ids_cpu_tensor = input_batch.token_ids_cpu_tensor``) that
    re-bound the name to the very object just stored on the mock.
    """
    max_num_reqs = 4
    max_model_len = 4096
    max_num_tokens = 4096
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    # num_speculative_tokens > 0 enables the MTP path under test.
    vllm_config.speculative_config.num_speculative_tokens = 1
    vllm_config.scheduler_config.max_num_seqs = max_num_reqs
    vllm_config.scheduler_config.max_num_batched_tokens = max_model_len
    pcp_manager = PCPManager(pcp_world_size=2,
                             pcp_rank=0,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=max_num_tokens,
                             max_num_reqs=max_num_reqs,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    arange_np = np.arange(max_model_len)
    input_batch = MagicMock()
    input_batch.num_computed_tokens_cpu = \
        np.zeros(max_num_reqs, dtype=np.int32)
    token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
    )
    input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
    # The numpy view shares storage with the tensor, as in the real batch.
    input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()

    # Populate the batch with the parametrized request state.
    input_batch.req_ids = req_ids
    input_batch.num_computed_tokens_cpu[:num_computed_tokens.
                                        size] = num_computed_tokens
    for i, token_ids_tensor in enumerate(token_ids_tensor_list):
        token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor

    pcp_manager.init_batch_info(np.array(list(num_scheduled_tokens.values())),
                                num_reqs)
    pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens,
                                       num_scheduled_tokens, False,
                                       input_batch, arange_np)
    assert torch.equal(
        pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
        target_input_ids_pcp_full)
    assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
                       target_query_start_loc_pcp_full)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "pcp_world_rank, split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs,"
    "head_attn_nomask_seqlens, chunk_seqlens,"
    "target_split_q_head, target_split_q_tail, target_head_seqlens, target_tail_seqlens",
    [
        # case1: pcp_world_rank=0
        (0, [[10, 20, 30]], [[40, 50, 60]],
         torch.tensor([[64], [0]], dtype=torch.int32), [64], [
             torch.tensor([1, 2, 3], dtype=torch.int32)
         ], [torch.tensor([40, 50, 60], dtype=torch.int32)], [
             torch.tensor([[64], [0]], dtype=torch.int32)
         ], [torch.tensor([[64], [3]], dtype=torch.int32)]),
        # case2: pcp_world_rank=1
        (1, [[1, 2], [3, 4, 5]], [[6, 7], [8, 9, 10]],
         torch.tensor([[128, 128], [128, 128]], dtype=torch.int32), [128, 128],
         [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)], [
             torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)
         ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)
             ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)]),
        # case3: pcp_world_rank=2
        (2, [[11, 12, 13, 14], [15, 16]], [[17, 18, 19], [20, 21, 22, 23]],
         torch.tensor([[256, 256], [512, 512]], dtype=torch.int32), [256, 256],
         [torch.tensor([11, 12, 13, 14, 15, 16], dtype=torch.int32)], [
             torch.tensor([17, 18, 19, 20, 21, 22, 23], dtype=torch.int32)
         ], [torch.tensor([[256, 256], [4, 2]], dtype=torch.int32)
             ], [torch.tensor([[256, 256], [3, 4]], dtype=torch.int32)]),
        # case4: empty input
        (
            0,
            [],
            [],
            torch.tensor([], dtype=torch.int32).reshape(2, 0),
            [],
            [],
            [],
            [],
            [],
        ),
        # case5: single element input
        (
            0,
            [[10]],
            [[40]],
            torch.tensor([[64], [0]], dtype=torch.int32),
            [64],
            [torch.tensor([1, 2, 3], dtype=torch.int32)],
            [torch.tensor([40], dtype=torch.int32)],
            [torch.tensor([[64], [0]], dtype=torch.int32)],
            [torch.tensor([[64], [1]], dtype=torch.int32)],
        ),
        # case6: pcp_world_rank=3
        (
            3,
            [[1, 2], [3, 4, 5]],
            [[6, 7], [8, 9, 10]],
            torch.tensor([[128, 128], [128, 128]], dtype=torch.int32),
            [128, 128],
            [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)],
            [torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
        ),
    ])
def test_split_nomask_idx_tensor_list(
        pcp_world_rank, split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens,
        chunk_seqlens, target_split_q_head, target_split_q_tail,
        target_head_seqlens, target_tail_seqlens):
    """_split_nomask_idx_tensor_list merges per-request no-mask KV index
    lists into per-chunk tensors plus matching seq-len tensors.

    Fixes: removed a dead ``mock_runner.pcp_world_rank = 0`` assignment that
    was immediately overwritten with the parametrized rank, and collapsed
    four copy-pasted assertion loops into one helper.
    """
    # Mock a PCPManager so only the method under test runs for real.
    mock_runner = MagicMock(spec=PCPManager)
    mock_runner.device = "cpu"
    mock_runner.pcp_world_rank = pcp_world_rank
    mock_runner.kv_idx_names = {
        "kv_with_q_head_nomask_idx_tensor":
        torch.tensor([1, 2, 3], dtype=torch.int32)
    }

    # Bind the real helper methods the method under test delegates to.
    mock_runner._split_multi_batch_kv_idx.side_effect = PCPManager._split_multi_batch_kv_idx.__get__(
        mock_runner, PCPManager)
    mock_runner._list_to_tensor.side_effect = PCPManager._list_to_tensor.__get__(
        mock_runner, PCPManager)

    # Call the method under test
    result = PCPManager._split_nomask_idx_tensor_list(
        mock_runner,
        split_with_q_head_nomask_idx_reqs=split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs=
        split_kv_with_q_tail_nomask_idx_reqs,
        head_attn_nomask_seqlens=head_attn_nomask_seqlens,
        chunk_seqlens=chunk_seqlens)
    split_q_head, split_q_tail, head_seqlens, tail_seqlens = result

    def _assert_matches(actual_list, target_list):
        """Element-wise comparison; tensors via torch.equal, rest via ==."""
        assert len(actual_list) == len(target_list)
        for res, target in zip(actual_list, target_list):
            if isinstance(target, torch.Tensor):
                assert torch.equal(res, target)
            else:
                assert res == target

    _assert_matches(split_q_head, target_split_q_head)
    _assert_matches(split_q_tail, target_split_q_tail)
    _assert_matches(head_seqlens, target_head_seqlens)
    _assert_matches(tail_seqlens, target_tail_seqlens)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kv_nomask_idx_multi_batch, split_size, expected_merged_idx, expected_merged_len",
    [
        # case1: multiple batches + split size greater than batch length
        (
            [[0, 1, 2, 3, 4], [5, 6, 7]],
            2,
            # expected merged_split_kv_idx_3d
            [[0, 1, 5, 6], [2, 3, 7], [4]],
            # expected merged_split_kv_len_2d
            [[2, 2], [2, 1], [1, 0]],
        ),
        # case2: single batch + split size greater than batch length
        (
            [[0, 1, 2]],
            5,
            [[0, 1, 2]],
            [[3]],
        ),
        # case3: split size equals maximum batch length
        (
            [[0, 1, 2, 3], [5, 6]],
            4,
            [[0, 1, 2, 3, 5, 6]],
            [[4, 2]],
        ),
        # case4: split size is 1 (minimum granularity split)
        (
            [[0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[1, 1], [1, 0]],
        ),
        # case5: the batch contains an empty list
        (
            [[], [0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[0, 1, 1], [0, 1, 0]],
        ),
        # case6: empty input
        (
            [],
            2,
            [],
            [],
        ),
    ])
def test_split_multi_batch_kv_idx(
    kv_nomask_idx_multi_batch,
    split_size,
    expected_merged_idx,
    expected_merged_len,
):
    """_split_multi_batch_kv_idx chops each batch's index list into
    `split_size`-sized segments, then merges segment i of every batch into
    one entry, with per-batch segment lengths alongside (0-padded for
    exhausted batches).

    Fix: dropped the unused ``enumerate`` index from both comparison loops.
    """
    # Only the static helper runs for real; no manager state is needed.
    model_runner = MagicMock(spec=PCPManager)

    # Call the method under test
    merged_split_kv_idx_3d, merged_split_kv_len_2d = \
        PCPManager._split_multi_batch_kv_idx(
            self=model_runner,
            kv_nomask_idx_multi_batch=kv_nomask_idx_multi_batch,
            split_size=split_size)

    assert len(merged_split_kv_idx_3d) == len(expected_merged_idx)
    for actual_seg, expected_seg in zip(merged_split_kv_idx_3d,
                                        expected_merged_idx):
        assert actual_seg == expected_seg

    assert len(merged_split_kv_len_2d) == len(expected_merged_len)
    for actual_len, expected_len in zip(merged_split_kv_len_2d,
                                        expected_merged_len):
        assert actual_len == expected_len
|