### What this PR does / why we need it?
PR https://github.com/vllm-project/vllm/pull/32118 changed the criteria vLLM uses to classify Prefill and Decode requests, so PCPManager must be brought in line with that standard. Because PCPManager previously recomputed the prefill/decode request counts in several places, this PR also consolidates that logic so the counts are updated once per batch, as sketched below.
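A minimal sketch of the consolidated counting, assuming the phase-based criterion the new tests rely on (a request counts as decode once `num_computed_tokens >= num_prompt_tokens`); the helper name is illustrative, not the actual implementation:
```python
import numpy as np

# Hypothetical sketch: classify the whole batch once, instead of
# recomputing prefill/decode counts at each call site.
def count_pd_requests(num_computed_tokens: np.ndarray,
                      num_prompt_tokens: np.ndarray) -> tuple[int, int]:
    is_decode = num_computed_tokens >= num_prompt_tokens
    num_decodes = int(is_decode.sum())
    num_prefills = int(is_decode.size - num_decodes)
    return num_decodes, num_prefills
```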
### How was this patch tested?
```bash
pytest tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
```
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch

from vllm_ascend.worker.pcp_utils import PCPManager


@pytest.mark.parametrize(
    "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
    [
        (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
        (1, 2, 3, [20, 30, 40], 1, False, 50, True),
        (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
        (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
        (2, 1, 3, [5, 10, 15], 3, False, 50, True),
        (2, 1, 3, [40, 50, 60], 0, False, 150, True),
    ])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                                     num_decodes, use_mla, total_tokens,
                                     expect_not_none):
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.model_config.use_mla = use_mla
    vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
    vllm_config.speculative_config.num_speculative_tokens = 0

    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=0,
                             dcp_world_size=dcp_size,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs

    num_computed_tokens = []
    num_prompt_tokens = []
    num_tokens = []

    # Decode requests come first in the batch: their prompts are already
    # fully computed (num_computed_tokens >= num_prompt_tokens). The
    # remaining requests are prefills with no computed tokens yet.
    for i in range(num_reqs):
        if i < num_decodes:
            num_computed_tokens.append(query_lens[i])
            num_prompt_tokens.append(query_lens[i] // 2)
            num_tokens.append(query_lens[i])
        else:
            num_computed_tokens.append(0)
            num_prompt_tokens.append(query_lens[i])
            num_tokens.append(query_lens[i])

    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
    input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
    input_batch.num_tokens = torch.tensor(num_tokens)
    num_scheduled_tokens = np.array(
        query_lens) - input_batch.num_computed_tokens_cpu

    query_lens = torch.tensor(query_lens)
    result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
                                               input_batch,
                                               num_scheduled_tokens)

    if not expect_not_none:
        assert result is None, f"Expected to return None, but got {type(result)}"
    else:
        assert result is not None, "Expected to return a metadata object, but got None."

        assert hasattr(result, 'num_actual_tokens_pcp_padded')
        assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')

        if pcp_size > 1:
            assert hasattr(result, 'pcp_allgather_restore_idx')

        has_prefill_requests = (num_reqs - num_decodes) > 0
        if has_prefill_requests:
            assert hasattr(result, 'q_head_idx_tensor')
            assert hasattr(result, 'q_tail_idx_tensor')
            assert hasattr(result, 'q_full_idx')
            assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor')
            assert hasattr(result, 'kv_with_q_head_mask_idx_tensor')
            assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor')
            assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor')
            assert hasattr(result, 'attn_mask_seqlens')
            assert hasattr(result, 'head_attn_nomask_seqlens')
            assert hasattr(result, 'tail_attn_nomask_seqlens')


@pytest.mark.parametrize(
    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
    [
        # Case 1: prefill only
        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
        # Case 2: mixed prefill and decode
        ([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]),
        # Case 3: requests that need padding
        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
        # Case 4: single request
        ([10], 1, [0], [10], 4, 0, [4]),
    ])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                     num_prompt_tokens, pcp_size, pcp_rank,
                                     expected_pcp_tokens):
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 0
    vllm_config.scheduler_config.max_num_seqs = 1000

    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=pcp_rank,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs
    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens,
                                                   dtype=np.int32)
    input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
    arange_np = np.arange(10000)
    num_scheduled_tokens = np.array(tokens)
    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
    pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
        num_scheduled_tokens, arange_np)

    assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
        f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"

    total_pcp_tokens: int = np.sum(pcp_tokens_result)
    assert positions_result.shape == (total_pcp_tokens,), \
        f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"


# yapf: disable
@pytest.mark.parametrize(
    "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
    [
        # without pcp and dcp
        (torch.tensor([1, 2, 128, 129]), 1, 1, 1,
         torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
        # pcp
        (torch.tensor([1, 2, 128, 129]), 2, 1, 1,
         torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
        # dcp
        (torch.tensor([1, 2, 128, 129]), 1, 2, 1,
         torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
        # pcp + dcp
        (torch.tensor([1, 2, 128, 129]), 2, 2, 1,
         torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
                       [[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
        # specify interleave_size
        (torch.tensor([1, 2, 128, 129]), 2, 1, 2,
         torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
        (torch.tensor([1, 2, 128, 129]), 2, 1, 128,
         torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
        (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
         torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
                       [[128, 0], [0, 0]], [[128, 1], [0, 0]],
                       [[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
    ]
)
# yapf: enable
def test_get_cp_local_seq_lens(
    seq_lens,
    pcp_world_size,
    dcp_world_size,
    cp_kv_cache_interleave_size,
    target,
):
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 0
    pcp_manager = PCPManager(pcp_world_size=pcp_world_size,
                             pcp_rank=0,
                             dcp_world_size=dcp_world_size,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
                                             dcp_world_size,
                                             cp_kv_cache_interleave_size)
    assert torch.equal(ret, target)
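

# The targets above are consistent with interleaved round-robin KV
# placement: token i lands on flat CP rank (i // interleave_size) %
# (pcp_world_size * dcp_world_size), with per-rank counts viewed as a
# [pcp, dcp] grid. A reference sketch inferred from the cases (not the
# production helper):
def _reference_cp_local_seq_lens(seq_lens, pcp, dcp, interleave):
    world = pcp * dcp
    out = torch.zeros(len(seq_lens), pcp, dcp, dtype=torch.long)
    for b, n in enumerate(seq_lens.tolist()):
        full, rem = divmod(n, interleave * world)
        counts = torch.full((world, ), full * interleave, dtype=torch.long)
        rank = 0
        while rem > 0:
            # hand out the remainder one interleave-sized block at a time
            counts[rank] += min(interleave, rem)
            rem -= interleave
            rank += 1
        out[b] = counts.view(pcp, dcp)
    return out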


# yapf: disable
@pytest.mark.parametrize(
    "req_ids, num_computed_tokens,"
    "token_ids_tensor_list,"
    "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens,"
    "target_input_ids_pcp_full, target_query_start_loc_pcp_full",
    [
        # prefill
        (
            ['0'], np.array([0]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344])],
            1, 6, {'0': 6},
            torch.tensor([0, 671, 6102, 294, 8760, 344]),
            torch.tensor([0, 6])
        ),
        # decode
        (
            ['0'], np.array([6]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
            1, 2, {'0': 2},
            torch.tensor([88907, 0]),
            torch.tensor([0, 2])
        ),
        # decode + prefill
        (
            ['0', '1'], np.array([6, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            ],
            2, 12, {'0': 2, '1': 10},
            torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            torch.tensor([0, 2, 12])
        ),
        # decodes + prefills
        (
            ['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
                torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
                torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
            ],
            4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
            torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
                          0, 671, 5335, 1469, 7539, 305, 6397]),
            torch.tensor([0, 2, 4, 12, 19])
        ),
    ])
# yapf: enable
def test_generate_pcp_mtp_input(
    req_ids,
    num_computed_tokens,
    token_ids_tensor_list,
    num_reqs,
    total_num_scheduled_tokens,
    num_scheduled_tokens,
    target_input_ids_pcp_full,
    target_query_start_loc_pcp_full,
):
    max_num_reqs = 4
    max_model_len = 4096
    max_num_tokens = 4096
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 1
    vllm_config.scheduler_config.max_num_seqs = max_num_reqs
    vllm_config.scheduler_config.max_num_batched_tokens = max_model_len
    pcp_manager = PCPManager(pcp_world_size=2,
                             pcp_rank=0,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=max_num_tokens,
                             max_num_reqs=max_num_reqs,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    arange_np = np.arange(max_model_len)
    input_batch = MagicMock()
    input_batch.num_computed_tokens_cpu = \
        np.zeros(max_num_reqs, dtype=np.int32)
    token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
    )
    input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
    input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()

    # Populate the input batch with per-request progress and token ids.
    input_batch.req_ids = req_ids
    input_batch.num_computed_tokens_cpu[:num_computed_tokens.size] = \
        num_computed_tokens
    for i, token_ids_tensor in enumerate(token_ids_tensor_list):
        token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor

    pcp_manager.init_batch_info(
        np.array(list(num_scheduled_tokens.values())), num_reqs)
    pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens,
                                       num_scheduled_tokens, False,
                                       input_batch, arange_np)
    assert torch.equal(
        pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
        target_input_ids_pcp_full)
    assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
                       target_query_start_loc_pcp_full)
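

# Per the targets above, the full MTP input gathers each request's tokens
# for this step, token_ids[computed:computed + scheduled], concatenated in
# batch order, with query_start_loc as the running sum of scheduled tokens.
# A reference sketch of that expectation (hypothetical helper, not the code
# path under test):
def _reference_mtp_full_inputs(token_ids_cpu_tensor, num_computed_tokens,
                               num_scheduled_tokens):
    chunks = [
        token_ids_cpu_tensor[i, c:c + s] for i, (c, s) in enumerate(
            zip(num_computed_tokens, num_scheduled_tokens))
    ]
    query_start_loc = np.concatenate(([0], np.cumsum(num_scheduled_tokens)))
    return torch.cat(chunks), torch.from_numpy(query_start_loc)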


@pytest.mark.parametrize(
    "pcp_world_rank, split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs,"
    "head_attn_nomask_seqlens, chunk_seqlens,"
    "target_split_q_head, target_split_q_tail, target_head_seqlens, target_tail_seqlens",
    [
        # case1: pcp_world_rank=0
        (
            0,
            [[10, 20, 30]],
            [[40, 50, 60]],
            torch.tensor([[64], [0]], dtype=torch.int32),
            [64],
            [torch.tensor([1, 2, 3], dtype=torch.int32)],
            [torch.tensor([40, 50, 60], dtype=torch.int32)],
            [torch.tensor([[64], [0]], dtype=torch.int32)],
            [torch.tensor([[64], [3]], dtype=torch.int32)],
        ),
        # case2: pcp_world_rank=1
        (
            1,
            [[1, 2], [3, 4, 5]],
            [[6, 7], [8, 9, 10]],
            torch.tensor([[128, 128], [128, 128]], dtype=torch.int32),
            [128, 128],
            [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)],
            [torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
        ),
        # case3: pcp_world_rank=2
        (
            2,
            [[11, 12, 13, 14], [15, 16]],
            [[17, 18, 19], [20, 21, 22, 23]],
            torch.tensor([[256, 256], [512, 512]], dtype=torch.int32),
            [256, 256],
            [torch.tensor([11, 12, 13, 14, 15, 16], dtype=torch.int32)],
            [torch.tensor([17, 18, 19, 20, 21, 22, 23], dtype=torch.int32)],
            [torch.tensor([[256, 256], [4, 2]], dtype=torch.int32)],
            [torch.tensor([[256, 256], [3, 4]], dtype=torch.int32)],
        ),
        # case4: empty input
        (
            0,
            [],
            [],
            torch.tensor([], dtype=torch.int32).reshape(2, 0),
            [],
            [],
            [],
            [],
            [],
        ),
        # case5: single-element input
        (
            0,
            [[10]],
            [[40]],
            torch.tensor([[64], [0]], dtype=torch.int32),
            [64],
            [torch.tensor([1, 2, 3], dtype=torch.int32)],
            [torch.tensor([40], dtype=torch.int32)],
            [torch.tensor([[64], [0]], dtype=torch.int32)],
            [torch.tensor([[64], [1]], dtype=torch.int32)],
        ),
        # case6: pcp_world_rank=3
        (
            3,
            [[1, 2], [3, 4, 5]],
            [[6, 7], [8, 9, 10]],
            torch.tensor([[128, 128], [128, 128]], dtype=torch.int32),
            [128, 128],
            [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)],
            [torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
        ),
    ])
def test_split_nomask_idx_tensor_list(
        pcp_world_rank, split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens,
        chunk_seqlens, target_split_q_head, target_split_q_tail,
        target_head_seqlens, target_tail_seqlens):
    # Mock a PCPManager instance, providing only the state the helper reads.
    mock_runner = MagicMock(spec=PCPManager)
    mock_runner.device = "cpu"
    mock_runner.pcp_world_rank = pcp_world_rank
    mock_runner.kv_idx_names = {
        "kv_with_q_head_nomask_idx_tensor":
        torch.tensor([1, 2, 3], dtype=torch.int32)
    }

    # Bind the real helper implementations to the mock.
    mock_runner._split_multi_batch_kv_idx.side_effect = \
        PCPManager._split_multi_batch_kv_idx.__get__(mock_runner, PCPManager)
    mock_runner._list_to_tensor.side_effect = \
        PCPManager._list_to_tensor.__get__(mock_runner, PCPManager)

    # Call the method under test
    result = PCPManager._split_nomask_idx_tensor_list(
        mock_runner,
        split_with_q_head_nomask_idx_reqs=split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs=(
            split_kv_with_q_tail_nomask_idx_reqs),
        head_attn_nomask_seqlens=head_attn_nomask_seqlens,
        chunk_seqlens=chunk_seqlens)
    split_q_head, split_q_tail, head_seqlens, tail_seqlens = result

    # Compare each returned list element-wise against its target.
    assert len(split_q_head) == len(target_split_q_head)
    for res, target in zip(split_q_head, target_split_q_head):
        assert torch.equal(res, target)

    assert len(split_q_tail) == len(target_split_q_tail)
    for res, target in zip(split_q_tail, target_split_q_tail):
        assert torch.equal(res, target)

    assert len(head_seqlens) == len(target_head_seqlens)
    for res, target in zip(head_seqlens, target_head_seqlens):
        if isinstance(target, torch.Tensor):
            assert torch.equal(res, target)
        else:
            assert res == target

    assert len(tail_seqlens) == len(target_tail_seqlens)
    for res, target in zip(tail_seqlens, target_tail_seqlens):
        if isinstance(target, torch.Tensor):
            assert torch.equal(res, target)
        else:
            assert res == target


@pytest.mark.parametrize(
    "kv_nomask_idx_multi_batch, split_size, expected_merged_idx, expected_merged_len",
    [
        # case1: multiple batches + split size smaller than batch length
        (
            [[0, 1, 2, 3, 4], [5, 6, 7]],
            2,
            # expected merged_split_kv_idx_3d
            [[0, 1, 5, 6], [2, 3, 7], [4]],
            # expected merged_split_kv_len_2d
            [[2, 2], [2, 1], [1, 0]],
        ),
        # case2: single batch + split size greater than batch length
        (
            [[0, 1, 2]],
            5,
            [[0, 1, 2]],
            [[3]],
        ),
        # case3: split size equals maximum batch length
        (
            [[0, 1, 2, 3], [5, 6]],
            4,
            [[0, 1, 2, 3, 5, 6]],
            [[4, 2]],
        ),
        # case4: split size is 1 (minimum-granularity split)
        (
            [[0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[1, 1], [1, 0]],
        ),
        # case5: a batch contains an empty list
        (
            [[], [0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[0, 1, 1], [0, 1, 0]],
        ),
        # case6: empty input
        (
            [],
            2,
            [],
            [],
        ),
    ])
def test_split_multi_batch_kv_idx(
    kv_nomask_idx_multi_batch,
    split_size,
    expected_merged_idx,
    expected_merged_len,
):
    # The helper under test is stateless, so a bare mock stands in for self.
    model_runner = MagicMock(spec=PCPManager)

    # Call the method under test
    result = PCPManager._split_multi_batch_kv_idx(
        self=model_runner,
        kv_nomask_idx_multi_batch=kv_nomask_idx_multi_batch,
        split_size=split_size)

    merged_split_kv_idx_3d, merged_split_kv_len_2d = result

    # Check the merged chunk contents and the per-batch chunk lengths.
    assert len(merged_split_kv_idx_3d) == len(expected_merged_idx)
    for actual_seg, expected_seg in zip(merged_split_kv_idx_3d,
                                        expected_merged_idx):
        assert actual_seg == expected_seg

    assert len(merged_split_kv_len_2d) == len(expected_merged_len)
    for actual_len, expected_len in zip(merged_split_kv_len_2d,
                                        expected_merged_len):
        assert actual_len == expected_len
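

# The expected values above describe a ragged transpose: each batch's index
# list is cut into split_size-sized chunks, and chunk t of every batch is
# merged, recording the per-batch chunk lengths. A reference sketch inferred
# from the cases (not the production helper):
def _reference_split_multi_batch_kv_idx(batches, split_size):
    chunked = [[b[i:i + split_size] for i in range(0, len(b), split_size)]
               for b in batches]
    num_chunks = max((len(c) for c in chunked), default=0)
    merged_idx, merged_len = [], []
    for t in range(num_chunks):
        row = [c[t] if t < len(c) else [] for c in chunked]
        merged_idx.append([idx for part in row for idx in part])
        merged_len.append([len(part) for part in row])
    return merged_idx, merged_len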