From 07014e2101ce5bd2d9d3198f2e5fab9f2717975a Mon Sep 17 00:00:00 2001 From: zhangsicheng5 Date: Thu, 18 Dec 2025 10:54:57 +0800 Subject: [PATCH] [UT] Add model_runner pcp related UTs (#4951) 1. Add some uts for pcp related functions in NPUModelRunner - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: zhangsicheng5 Co-authored-by: wangxiyuan --- tests/ut/worker/test_model_runner_v1.py | 161 ++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index 6d6ef054..facb8920 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -302,3 +302,164 @@ def test_update_tokens_for_pcp_unpad_mask(): actual_mask = unpad_mask.numpy().tolist() assert actual_mask == expected_mask, \ f"unpad_mask incorrect. Expected {expected_mask}, got {actual_mask}" + + +# yapf: disable +@pytest.mark.parametrize( + "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target", + [ + # without pcp and dcp + (torch.tensor([1, 2, 128, 129]), 1, 1, 1, + torch.tensor([[[1]], [[2]], [[128]], [[129]]])), + # pcp + (torch.tensor([1, 2, 128, 129]), 2, 1, 1, + torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])), + # dcp + (torch.tensor([1, 2, 128, 129]), 1, 2, 1, + torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])), + # pcp + dcp + (torch.tensor([1, 2, 128, 129]), 2, 2, 1, + torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]], + [[32, 32], [32, 32]], [[33, 32], [32, 32]]])), + # specify interleave_size + (torch.tensor([1, 2, 128, 129]), 2, 1, 2, + torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])), + (torch.tensor([1, 2, 128, 129]), 2, 1, 128, + torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])), + (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128, + torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]], + [[128, 0], [0, 0]], [[128, 1], [0, 0]], + [[128, 128], [0, 0]], [[128, 128], [1, 0]]])), + ] +) +# yapf: enable +def test_get_cp_local_seq_lens( + seq_lens, + pcp_world_size, + dcp_world_size, + cp_kv_cache_interleave_size, + target, +): + mock_runner = MagicMock(spec=NPUModelRunner) + ret = NPUModelRunner._get_cp_local_seq_lens(mock_runner, seq_lens, + pcp_world_size, dcp_world_size, + cp_kv_cache_interleave_size) + assert torch.equal(ret, target) + + +@pytest.fixture +def pcp_mtp_mock_runner(): + # set up pcp & mtp related buffers + max_num_reqs = 4 + max_model_len = 4096 + max_num_tokens = 4096 + mock_runner = MagicMock(spec=NPUModelRunner) + mock_runner.device = 'cpu' + mock_runner.pin_memory = False + + # Init model_runner pcp_mtp related buffers + mock_runner.query_start_loc_pcp_full = NPUModelRunner._make_buffer( + mock_runner, max_num_reqs + 1, dtype=torch.int32) + + positions_buff = torch.zeros(max_num_tokens, + dtype=torch.int64, + device="cpu") + mock_runner.positions_pcp_full = positions_buff + mock_runner.positions_pcp_full_np = positions_buff.numpy() + + mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer( + mock_runner, max_num_tokens, dtype=torch.int32) + + mock_runner.arange_np = np.arange(max_model_len) + mock_runner.input_batch = MagicMock() + mock_runner.input_batch.num_computed_tokens_cpu = \ + np.zeros(max_num_reqs, dtype=np.int32) + token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + ) + mock_runner.input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor + mock_runner.input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy() + return mock_runner + + +# yapf: disable +@pytest.mark.parametrize( + "req_ids, num_computed_tokens," \ + "token_ids_tensor_list," \ + "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \ + "target_input_ids_pcp_full, target_query_start_loc_pcp_full", + [ + # prefill + ( + ['0'], np.array([0]), + [torch.tensor([0, 671, 6102, 294, 8760, 344])], + 1, 6, {'0': 6}, + torch.tensor([0, 671, 6102, 294, 8760, 344]), + torch.tensor([0, 6]) + ), + # decode + ( + ['0'], np.array([6]), + [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])], + 1, 2, {'0': 2}, + torch.tensor([88907, 0]), + torch.tensor([0, 2]) + ), + # decode + prefill + ( + ['0', '1'], np.array([6, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + ], + 2, 12, {'0': 2, '1': 10}, + torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + torch.tensor([0, 2, 12]) + ), + # decodes + prefills + ( + ['0', '1', '2', '3'], np.array([6, 8, 0, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]), + torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]), + torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]), + ], + 4, 19, {'0': 2, '1': 2, '2': 8, '3': 7}, + torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907, + 0, 671, 5335, 1469, 7539, 305, 6397]), + torch.tensor([0, 2, 4, 12, 19]) + ), + ]) +# yapf: enable +def test_generate_pcp_mtp_input( + pcp_mtp_mock_runner, + req_ids, + num_computed_tokens, + token_ids_tensor_list, + num_reqs, + total_num_scheduled_tokens, + num_scheduled_tokens, + target_input_ids_pcp_full, + target_query_start_loc_pcp_full, +): + mock_runner = pcp_mtp_mock_runner + token_ids_cpu_tensor = mock_runner.input_batch.token_ids_cpu_tensor + + # Set input_batch + mock_runner.input_batch.req_ids = req_ids + mock_runner.input_batch.num_computed_tokens_cpu[:num_computed_tokens. + size] = num_computed_tokens + for i, token_ids_tensor in enumerate(token_ids_tensor_list): + token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor + + NPUModelRunner._generate_pcp_mtp_input(mock_runner, num_reqs, + total_num_scheduled_tokens, + num_scheduled_tokens) + assert torch.equal( + mock_runner.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], + target_input_ids_pcp_full) + assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1], + target_query_start_loc_pcp_full)