### What this PR does / why we need it?
Fixed the issue where the PCP and MTP services could not be started due
to asynchronous scheduling.
After the pcp, mtp, and asynchronous scheduling functions are enabled,
the service is suspended because of a shape mismatch after a curl
request is sent. This PR resolves this issue.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
513 lines · 20 KiB · Python
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm_ascend.worker.pcp_utils import PCPManager
|
|
|
|
|
|
@pytest.mark.parametrize(
    "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none",
    [
        (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False),
        (1, 2, 3, [20, 30, 40], 1, False, 50, True),
        (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True),
        (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True),
        (2, 1, 3, [5, 10, 15], 3, False, 50, True),
        (2, 1, 3, [40, 50, 60], 0, False, 150, True),
    ])
def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                                     num_decodes, use_mla, total_tokens,
                                     expect_not_none):
    """generate_pcp_metadata returns None when pcp/dcp are both disabled,
    otherwise a metadata object carrying the pcp/dcp index tensors."""
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.model_config.use_mla = use_mla
    vllm_config.parallel_config.cp_kv_cache_interleave_size = 64
    vllm_config.speculative_config.num_speculative_tokens = 0

    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=0,
                             dcp_world_size=dcp_size,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs

    # The first `num_decodes` requests are decodes (prompt fully computed);
    # the remaining ones are fresh prefills with nothing computed yet.
    num_computed_tokens = [
        q if i < num_decodes else 0 for i, q in enumerate(query_lens)
    ]
    num_prompt_tokens = [
        q // 2 if i < num_decodes else q for i, q in enumerate(query_lens)
    ]
    num_tokens = list(query_lens)

    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
    input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
    input_batch.num_tokens = torch.tensor(num_tokens)
    num_scheduled_tokens = np.array(
        query_lens) - input_batch.num_computed_tokens_cpu

    query_lens = torch.tensor(query_lens)
    result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
                                               input_batch,
                                               num_scheduled_tokens)

    if not expect_not_none:
        assert result is None, f"Expected to return None, but got {type(result)}"
        return

    assert result is not None, "Expected to return a metadata object, but got None."

    assert hasattr(result, 'num_actual_tokens_pcp_padded')
    assert hasattr(result, 'num_computed_tokens_of_pcp_dcp')

    if pcp_size > 1:
        assert hasattr(result, 'pcp_allgather_restore_idx')

    has_prefill_requests = (num_reqs - num_decodes) > 0
    if has_prefill_requests:
        # Prefill requests require the full set of q/kv index and
        # sequence-length tensors for head/tail attention.
        for attr_name in ('q_head_idx_tensor', 'q_tail_idx_tensor',
                          'q_full_idx', 'kv_with_q_head_nomask_idx_tensor',
                          'kv_with_q_head_mask_idx_tensor',
                          'kv_with_q_tail_nomask_idx_tensor',
                          'kv_with_q_tail_mask_idx_tensor',
                          'attn_mask_seqlens', 'head_attn_nomask_seqlens',
                          'tail_attn_nomask_seqlens'):
            assert hasattr(result, attr_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
    [
        # Case 1: prefill only
        ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]),
        # Case 2: mix prefill and decode
        ([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]),
        # Case 3: request which needs to be padded
        ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]),
        # Case 4: single request
        ([10], 1, [0], [10], 4, 0, [4]),
    ])
def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                     num_prompt_tokens, pcp_size, pcp_rank,
                                     expected_pcp_tokens):
    """update_tokens_for_pcp splits per-request token counts across pcp
    ranks and returns the local token counts plus their positions."""
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 0

    # Fix: the parametrized `pcp_rank` used to be ignored (0 was
    # hard-coded below), so rank-specific cases could never be expressed.
    # All current cases use rank 0, so behavior is unchanged.
    pcp_manager = PCPManager(pcp_world_size=pcp_size,
                             pcp_rank=pcp_rank,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=10000,
                             max_num_reqs=1000,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    input_batch = MagicMock()
    input_batch.num_reqs = num_reqs
    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens,
                                                   dtype=np.int32)
    input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32)
    arange_np = np.arange(10000)
    pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
        np.array(tokens), arange_np, num_reqs, 1)

    assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
        f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"

    # Positions must cover exactly the tokens assigned to this rank.
    total_pcp_tokens: int = np.sum(pcp_tokens_result)
    assert positions_result.shape == (total_pcp_tokens,), \
        f"Positions shape mismatch. Expected length {total_pcp_tokens}, got {positions_result.shape}"
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize(
    "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target",
    [
        # without pcp and dcp
        (torch.tensor([1, 2, 128, 129]), 1, 1, 1,
         torch.tensor([[[1]], [[2]], [[128]], [[129]]])),
        # pcp
        (torch.tensor([1, 2, 128, 129]), 2, 1, 1,
         torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])),
        # dcp
        (torch.tensor([1, 2, 128, 129]), 1, 2, 1,
         torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])),
        # pcp + dcp
        (torch.tensor([1, 2, 128, 129]), 2, 2, 1,
         torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]],
                       [[32, 32], [32, 32]], [[33, 32], [32, 32]]])),
        # specify interleave_size
        (torch.tensor([1, 2, 128, 129]), 2, 1, 2,
         torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])),
        (torch.tensor([1, 2, 128, 129]), 2, 1, 128,
         torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])),
        (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128,
         torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]],
                       [[128, 0], [0, 0]], [[128, 1], [0, 0]],
                       [[128, 128], [0, 0]], [[128, 128], [1, 0]]])),
    ]
)
# yapf: enable
def test_get_cp_local_seq_lens(
    seq_lens,
    pcp_world_size,
    dcp_world_size,
    cp_kv_cache_interleave_size,
    target,
):
    """_get_cp_local_seq_lens splits each sequence length over the
    (pcp_rank, dcp_rank) grid according to the interleave size."""
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 0
    manager = PCPManager(pcp_world_size=pcp_world_size,
                         pcp_rank=0,
                         dcp_world_size=dcp_world_size,
                         dcp_rank=0,
                         max_buffer_num_tokens=10000,
                         max_num_reqs=1000,
                         device="cpu",
                         vllm_config=vllm_config,
                         use_async_scheduling=False,
                         pin_memory=False)
    local_seq_lens = manager._get_cp_local_seq_lens(
        seq_lens, pcp_world_size, dcp_world_size,
        cp_kv_cache_interleave_size)
    assert torch.equal(local_seq_lens, target)
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize(
    "req_ids, num_computed_tokens," \
    "token_ids_tensor_list," \
    "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \
    "target_input_ids_pcp_full, target_query_start_loc_pcp_full",
    [
        # prefill
        (
            ['0'], np.array([0]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344])],
            1, 6, {'0': 6},
            torch.tensor([0, 671, 6102, 294, 8760, 344]),
            torch.tensor([0, 6])
        ),
        # decode
        (
            ['0'], np.array([6]),
            [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])],
            1, 2, {'0': 2},
            torch.tensor([88907, 0]),
            torch.tensor([0, 2])
        ),
        # decode + prefill
        (
            ['0', '1'], np.array([6, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            ],
            2, 12, {'0': 2, '1': 10},
            torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]),
            torch.tensor([0, 2, 12])
        ),
        # decodes + prefills
        (
            ['0', '1', '2', '3'], np.array([6, 8, 0, 0]),
            [
                torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]),
                torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]),
                torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]),
                torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]),
            ],
            4, 19, {'0': 2, '1': 2, '2': 8, '3': 7},
            torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907,
                          0, 671, 5335, 1469, 7539, 305, 6397]),
            torch.tensor([0, 2, 4, 12, 19])
        ),
    ])
# yapf: enable
def test_generate_pcp_mtp_input(
    req_ids,
    num_computed_tokens,
    token_ids_tensor_list,
    num_reqs,
    total_num_scheduled_tokens,
    num_scheduled_tokens,
    target_input_ids_pcp_full,
    target_query_start_loc_pcp_full,
):
    """generate_pcp_mtp_input gathers the scheduled token ids of every
    request into a flat buffer plus per-request start offsets."""
    max_num_reqs = 4
    max_model_len = 4096
    max_num_tokens = 4096
    vllm_config = MagicMock()
    vllm_config.model_config = MagicMock()
    vllm_config.speculative_config.num_speculative_tokens = 1
    vllm_config.scheduler_config.max_num_seqs = max_num_reqs
    vllm_config.scheduler_config.max_num_batched_tokens = max_model_len
    pcp_manager = PCPManager(pcp_world_size=2,
                             pcp_rank=0,
                             dcp_world_size=1,
                             dcp_rank=0,
                             max_buffer_num_tokens=max_num_tokens,
                             max_num_reqs=max_num_reqs,
                             device="cpu",
                             vllm_config=vllm_config,
                             use_async_scheduling=False,
                             pin_memory=False)
    arange_np = np.arange(max_model_len)
    input_batch = MagicMock()
    input_batch.num_computed_tokens_cpu = \
        np.zeros(max_num_reqs, dtype=np.int32)
    token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
    )
    # The numpy view shares storage with the tensor, matching the real
    # InputBatch layout.
    input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor
    input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy()

    # Set input_batch
    input_batch.req_ids = req_ids
    input_batch.num_computed_tokens_cpu[:num_computed_tokens.
                                        size] = num_computed_tokens
    for i, token_ids_tensor in enumerate(token_ids_tensor_list):
        # Single 2-D index instead of chained [i][:n] indexing.
        token_ids_cpu_tensor[i, :token_ids_tensor.size(0)] = token_ids_tensor

    pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
                                       num_scheduled_tokens, False,
                                       input_batch, arange_np)
    assert torch.equal(
        pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
        target_input_ids_pcp_full)
    assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
                       target_query_start_loc_pcp_full)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "pcp_world_rank, split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs,"
    "head_attn_nomask_seqlens, chunk_seqlens,"
    "target_split_q_head, target_split_q_tail, target_head_seqlens, target_tail_seqlens",
    [
        # case1: pcp_world_rank=0
        (0, [[10, 20, 30]], [[40, 50, 60]],
         torch.tensor([[64], [0]], dtype=torch.int32), [64], [
             torch.tensor([1, 2, 3], dtype=torch.int32)
         ], [torch.tensor([40, 50, 60], dtype=torch.int32)], [
             torch.tensor([[64], [0]], dtype=torch.int32)
         ], [torch.tensor([[64], [3]], dtype=torch.int32)]),
        # case2: pcp_world_rank=1
        (1, [[1, 2], [3, 4, 5]], [[6, 7], [8, 9, 10]],
         torch.tensor([[128, 128], [128, 128]], dtype=torch.int32), [128, 128],
         [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)], [
             torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)
         ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)
             ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)]),
        # case3: pcp_world_rank=2
        (2, [[11, 12, 13, 14], [15, 16]], [[17, 18, 19], [20, 21, 22, 23]],
         torch.tensor([[256, 256], [512, 512]], dtype=torch.int32), [256, 256],
         [torch.tensor([11, 12, 13, 14, 15, 16], dtype=torch.int32)], [
             torch.tensor([17, 18, 19, 20, 21, 22, 23], dtype=torch.int32)
         ], [torch.tensor([[256, 256], [4, 2]], dtype=torch.int32)
             ], [torch.tensor([[256, 256], [3, 4]], dtype=torch.int32)]),
        # case4: empty input
        (
            0,
            [],
            [],
            torch.tensor([], dtype=torch.int32).reshape(2, 0),
            [],
            [],
            [],
            [],
            [],
        ),
        # case5: single element input
        (
            0,
            [[10]],
            [[40]],
            torch.tensor([[64], [0]], dtype=torch.int32),
            [64],
            [torch.tensor([1, 2, 3], dtype=torch.int32)],
            [torch.tensor([40], dtype=torch.int32)],
            [torch.tensor([[64], [0]], dtype=torch.int32)],
            [torch.tensor([[64], [1]], dtype=torch.int32)],
        ),
        # case6: pcp_world_rank=3
        (
            3,
            [[1, 2], [3, 4, 5]],
            [[6, 7], [8, 9, 10]],
            torch.tensor([[128, 128], [128, 128]], dtype=torch.int32),
            [128, 128],
            [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)],
            [torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
            [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
        ),
    ])
def test_split_nomask_idx_tensor_list(
        pcp_world_rank, split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens,
        chunk_seqlens, target_split_q_head, target_split_q_tail,
        target_head_seqlens, target_tail_seqlens):
    """_split_nomask_idx_tensor_list merges the per-request no-mask kv
    indices into per-chunk tensors and sequence lengths."""
    # Mock input data: a spec'd MagicMock stands in for a real manager so
    # only the two real helpers bound below are exercised.
    mock_runner = MagicMock(spec=PCPManager)
    mock_runner.device = "cpu"
    mock_runner.kv_idx_names = {
        "kv_with_q_head_nomask_idx_tensor":
        torch.tensor([1, 2, 3], dtype=torch.int32)
    }
    mock_runner.pcp_world_rank = pcp_world_rank

    # Bind the real helper implementations so the method under test uses
    # real splitting/conversion logic instead of MagicMock auto-returns.
    mock_runner._split_multi_batch_kv_idx.side_effect = PCPManager._split_multi_batch_kv_idx.__get__(
        mock_runner, PCPManager)
    mock_runner._list_to_tensor.side_effect = PCPManager._list_to_tensor.__get__(
        mock_runner, PCPManager)

    # Call the method under test
    result = PCPManager._split_nomask_idx_tensor_list(
        mock_runner,
        split_with_q_head_nomask_idx_reqs=split_with_q_head_nomask_idx_reqs,
        split_kv_with_q_tail_nomask_idx_reqs=
        split_kv_with_q_tail_nomask_idx_reqs,
        head_attn_nomask_seqlens=head_attn_nomask_seqlens,
        chunk_seqlens=chunk_seqlens)
    split_q_head, split_q_tail, head_seqlens, tail_seqlens = result

    # Assert the method call
    assert len(split_q_head) == len(target_split_q_head)
    for res, target in zip(split_q_head, target_split_q_head):
        assert torch.equal(res, target)

    assert len(split_q_tail) == len(target_split_q_tail)
    for res, target in zip(split_q_tail, target_split_q_tail):
        assert torch.equal(res, target)

    assert len(head_seqlens) == len(target_head_seqlens)
    for res, target in zip(head_seqlens, target_head_seqlens):
        if isinstance(target, torch.Tensor):
            assert torch.equal(res, target)
        else:
            assert res == target

    assert len(tail_seqlens) == len(target_tail_seqlens)
    for res, target in zip(tail_seqlens, target_tail_seqlens):
        if isinstance(target, torch.Tensor):
            assert torch.equal(res, target)
        else:
            assert res == target
|
|
|
|
|
|
@pytest.mark.parametrize(
    "kv_nomask_idx_multi_batch, split_size, expected_merged_idx, expected_merged_len",
    [
        # case1: multiple batches + split size greater than batch length
        (
            [[0, 1, 2, 3, 4], [5, 6, 7]],
            2,
            # expected merged_split_kv_idx_3d
            [[0, 1, 5, 6], [2, 3, 7], [4]],
            # expected merged_split_kv_len_2d
            [[2, 2], [2, 1], [1, 0]],
        ),
        # case2: single batch + split size greater than batch length
        (
            [[0, 1, 2]],
            5,
            [[0, 1, 2]],
            [[3]],
        ),
        # case3: split size equals maximum batch length
        (
            [[0, 1, 2, 3], [5, 6]],
            4,
            [[0, 1, 2, 3, 5, 6]],
            [[4, 2]],
        ),
        # case4: Split size is 1 (minimum granularity split)
        (
            [[0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[1, 1], [1, 0]],
        ),
        # case6: the batch contains an empty list
        (
            [[], [0, 1], [2]],
            1,
            [[0, 2], [1]],
            [[0, 1, 1], [0, 1, 0]],
        ),
        # case7: empty input
        (
            [],
            2,
            [],
            [],
        ),
    ])
def test_split_multi_batch_kv_idx(
    kv_nomask_idx_multi_batch,
    split_size,
    expected_merged_idx,
    expected_merged_len,
):
    """_split_multi_batch_kv_idx chunks each batch's kv indices by
    `split_size` and merges chunks of the same rank across batches."""
    # Mock input data: no instance state is read, so a spec'd mock is
    # enough to stand in for `self`.
    model_runner = MagicMock(spec=PCPManager)

    # Call the method under test
    result = PCPManager._split_multi_batch_kv_idx(
        self=model_runner,
        kv_nomask_idx_multi_batch=kv_nomask_idx_multi_batch,
        split_size=split_size)

    merged_split_kv_idx_3d, merged_split_kv_len_2d = result

    # Assert the method call (plain zip: the enumerate index was unused).
    assert len(merged_split_kv_idx_3d) == len(expected_merged_idx)
    for actual_seg, expected_seg in zip(merged_split_kv_idx_3d,
                                        expected_merged_idx):
        assert actual_seg == expected_seg

    assert len(merged_split_kv_len_2d) == len(expected_merged_len)
    for actual_len, expected_len in zip(merged_split_kv_len_2d,
                                        expected_merged_len):
        assert actual_len == expected_len
|