Files
xc-llm-ascend/tests/ut/worker/test_pooling_model_runner.py
wangyanhui-cmss 4e29c5a808 Add ut for test_pooling_model_runner.py (#1640)
### What this PR does / why we need it?
 Add ut for test_pooling_model_runner.py

### Does this PR introduce _any_ user-facing change? N/A

### How was this patch tested?
 python -m unittest  test_pooling_model_runner.py


- vLLM version: v0.9.1
- vLLM main:
2e610deb72

---------

Signed-off-by: wangyanhui-cmss <wangyanhui_yewu@cmss.chinamobile.com>
2025-07-07 17:12:11 +08:00

356 lines
15 KiB
Python

import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.engine.arg_utils import EngineArgs
from vllm.pooling_params import PoolingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm_ascend.worker.pooling_model_runner import (
ModelInputForNPUWithPoolingMetadata, NPUPoolingModelRunner)
class TestPoolingModelRunner(unittest.TestCase):
"""Unit tests for the NPUPoolingModelRunner class."""
def _create_model_runner(self, model: str, *args,
**kwargs) -> NPUPoolingModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = NPUPoolingModelRunner(vllm_config=engine_config, )
return model_runner
def setUp(self):
"""Initialize test fixtures and common mocks"""
self.attn_backend = "npu"
model_runner = self._create_model_runner(
"tests/ut/fake_weight",
trust_remote_code=True,
enable_chunked_prefill=False,
)
self.runner = model_runner
self.runner.attn_backend = self.attn_backend
model_runner.model = MagicMock()
self.runner = model_runner
# Sample test data
self.sample_tensor_dict = {"tensor1": torch.randn(3, 4)}
self.sample_seq_group = [MagicMock(spec=SequenceGroupMetadata)]
self.sample_finished_ids = ["req1", "req2"]
@patch(
'vllm_ascend.worker.pooling_model_runner.ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict'
)
def test_make_model_input_from_broadcasted_tensor_dict(
self, mock_from_dict):
"""Test tensor dictionary conversion to model input"""
# Setup mock return
expected_output = MagicMock()
mock_from_dict.return_value = expected_output
# Execute
result = self.runner.make_model_input_from_broadcasted_tensor_dict(
self.sample_tensor_dict)
# Verify
mock_from_dict.assert_called_once_with(self.sample_tensor_dict,
attn_backend=self.attn_backend)
self.assertEqual(result, expected_output)
@patch.object(NPUPoolingModelRunner, '_prepare_pooling')
@patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
def test_prepare_model_input_normal_case(self, mock_prepare_tensors,
mock_prepare_pooling):
"""Test normal flow of model input preparation"""
# Setup mocks
mock_model_input = ModelInputForNPUWithPoolingMetadata(
seq_lens=[1, 2, 3])
mock_prepare_tensors.return_value = mock_model_input
mock_pooling_metadata = MagicMock()
mock_prepare_pooling.return_value = mock_pooling_metadata
# Execute
result = self.runner.prepare_model_input(
seq_group_metadata_list=self.sample_seq_group,
finished_requests_ids=self.sample_finished_ids)
# Verify
mock_prepare_tensors.assert_called_once_with(self.sample_seq_group,
self.sample_finished_ids)
mock_prepare_pooling.assert_called_once_with(self.sample_seq_group,
mock_model_input.seq_lens)
self.assertEqual(result.pooling_metadata, mock_pooling_metadata)
def test_prepare_model_input_null_sequence_group(self):
"""Test assertion when seq_group_metadata_list is None"""
with self.assertRaises(AssertionError):
self.runner.prepare_model_input(
seq_group_metadata_list=None,
finished_requests_ids=self.sample_finished_ids)
@patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
def test_prepare_model_input_null_seq_lens(self, mock_prepare_tensors):
"""Test assertion when seq_lens is None in model input"""
# Setup mock with None seq_lens
mock_model_input = MagicMock()
mock_model_input.seq_lens = None
mock_prepare_tensors.return_value = mock_model_input
with self.assertRaises(AssertionError):
self.runner.prepare_model_input(
seq_group_metadata_list=self.sample_seq_group,
finished_requests_ids=self.sample_finished_ids)
@patch.object(NPUPoolingModelRunner, '_prepare_pooling')
@patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
def test_prepare_model_input_with_virtual_engine(self,
mock_prepare_tensors,
mock_prepare_pooling):
"""Test virtual engine parameter is properly handled"""
# Setup mocks
mock_model_input = ModelInputForNPUWithPoolingMetadata(
seq_lens=[1, 2, 3])
mock_prepare_tensors.return_value = mock_model_input
# Execute with virtual_engine parameter
result = self.runner.prepare_model_input(
seq_group_metadata_list=self.sample_seq_group,
virtual_engine=1,
finished_requests_ids=self.sample_finished_ids)
# Verify virtual_engine doesn't affect the flow
self.assertIsNotNone(result)
@patch.object(NPUPoolingModelRunner, '_prepare_pooling')
@patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
def test_prepare_model_input_with_null_finished_ids(
self, mock_prepare_tensors, mock_prepare_pooling):
"""Test case when finished_requests_ids is None"""
# Setup mocks
mock_model_input = ModelInputForNPUWithPoolingMetadata(
seq_lens=[1, 2, 3])
mock_prepare_tensors.return_value = mock_model_input
# Execute with None finished_ids
result = self.runner.prepare_model_input(
seq_group_metadata_list=self.sample_seq_group,
finished_requests_ids=None)
# Verify
mock_prepare_tensors.assert_called_once_with(self.sample_seq_group,
None)
self.assertIsNotNone(result)
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
def test_prepare_pooling_normal_case(self, mock_pooling_metadata):
"""Test normal case with multiple sequences in group"""
# Setup test data
mock_pooling_metadata.return_value = None
seq_data = {
1: MagicMock(spec=SequenceData),
2: MagicMock(spec=SequenceData)
}
pooling_params = MagicMock(spec=PoolingParams)
seq_group = MagicMock(spec=SequenceGroupMetadata)
seq_group.seq_data = seq_data
seq_group.pooling_params = pooling_params
# Call the function
self.runner._prepare_pooling([seq_group], [10, 20])
# Verify results
mock_pooling_metadata.assert_called_once_with(seq_groups=[
([1, 2], pooling_params)
],
seq_data=seq_data,
prompt_lens=[10, 20])
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
def test_prepare_pooling_empty_group(self, mock_pooling_metadata):
"""Test case with empty sequence group"""
# Setup empty group
mock_pooling_metadata.return_value = None
empty_seq_data: dict[int, SequenceData] = {}
pooling_params = MagicMock(spec=PoolingParams)
empty_group = MagicMock(spec=SequenceGroupMetadata)
empty_group.seq_data = empty_seq_data
empty_group.pooling_params = pooling_params
# Call the function
self.runner._prepare_pooling([empty_group], [])
# Verify results
mock_pooling_metadata.assert_called_once_with(seq_groups=[
([], pooling_params)
],
seq_data={},
prompt_lens=[])
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
def test_prepare_pooling_single_sequence(self, mock_pooling_metadata):
"""Test case with single sequence in group"""
# Setup single sequence
mock_pooling_metadata.return_value = None
single_seq_data = {3: MagicMock(spec=SequenceData)}
pooling_params = MagicMock(spec=PoolingParams)
single_group = MagicMock(spec=SequenceGroupMetadata)
single_group.seq_data = single_seq_data
single_group.pooling_params = pooling_params
# Call the function
self.runner._prepare_pooling([single_group], [5])
# Verify results
mock_pooling_metadata.assert_called_once_with(seq_groups=[
([3], pooling_params)
],
seq_data=single_seq_data,
prompt_lens=[5])
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
def test_prepare_pooling_multiple_groups(self, mock_pooling_metadata):
"""Test case with multiple sequence groups"""
# Setup multiple groups
mock_pooling_metadata.return_value = None
seq_data1 = {1: MagicMock(spec=SequenceData)}
seq_data2 = {2: MagicMock(spec=SequenceData)}
params1 = MagicMock(spec=PoolingParams)
params2 = MagicMock(spec=PoolingParams)
group1 = MagicMock(spec=SequenceGroupMetadata)
group1.seq_data = seq_data1
group1.pooling_params = params1
group2 = MagicMock(spec=SequenceGroupMetadata)
group2.seq_data = seq_data2
group2.pooling_params = params2
# Call the function
self.runner._prepare_pooling([group1, group2], [10, 20])
# Verify results
mock_pooling_metadata.assert_called_once_with(seq_groups=[
([1], params1), ([2], params2)
],
seq_data={
**seq_data1,
**seq_data2
},
prompt_lens=[10, 20])
@patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
def test_prepare_pooling_empty_input(self, mock_pooling_metadata):
"""Test case with empty input lists"""
# Call the function with empty inputs
mock_pooling_metadata.return_value = None
self.runner._prepare_pooling([], [])
# Verify results
mock_pooling_metadata.assert_called_once_with(seq_groups=[],
seq_data={},
prompt_lens=[])
@patch('vllm.forward_context.set_forward_context')
@patch('vllm.distributed.parallel_state._PP',
new_callable=lambda: MagicMock(spec=GroupCoordinator,
is_last_rank=True))
@patch('torch.npu.Event')
@patch.object(NPUPoolingModelRunner, 'set_active_loras')
@patch.object(NPUPoolingModelRunner, 'set_active_prompt_adapters')
def test_execute_model_normal_flow(self, mock_set_adapters, mock_set_loras,
mock_event, mock_pp, mock_set_forward):
"""Test normal execution path with all dependencies mocked"""
# Setup model input mock
mock_input = MagicMock()
mock_input.input_tokens = torch.tensor([1])
mock_input.input_positions = torch.tensor([0])
mock_input.multi_modal_kwargs = {}
self.runner.is_driver_worker = True
# Execute
self.runner.execute_model(model_input=mock_input,
kv_caches=[],
num_steps=1)
# Verify core calls
self.runner.model.pooler.assert_called_once()
@patch('vllm.forward_context.set_forward_context')
def test_execute_model_invalid_steps(self, mock_set_forward):
"""Test ValueError when num_steps != 1"""
with self.assertRaises(ValueError):
self.runner.execute_model(model_input=MagicMock(),
kv_caches=[],
num_steps=2)
mock_set_forward.assert_not_called()
@patch('vllm.forward_context.set_forward_context')
@patch('vllm.distributed.parallel_state._PP',
new_callable=lambda: MagicMock(spec=GroupCoordinator,
is_last_rank=False))
@patch('torch.npu.Event')
def test_execute_model_perf_monitoring(self, mock_event, mock_pp,
mock_set_forward):
"""Test performance monitoring with timing mocks"""
# Setup mocks
mock_event.return_value.elapsed_time.return_value = 15.0
self.runner.observability_config = MagicMock(
collect_model_forward_time=True)
# Execute
self.runner.execute_model(model_input=MagicMock(
input_tokens=torch.tensor([1]),
input_positions=torch.tensor([0]),
multi_modal_kwargs={}),
kv_caches=[],
num_steps=1)
# Verify timing calls
self.assertEqual(mock_event.call_count, 2)
@patch('vllm.forward_context.set_forward_context')
@patch.object(NPUPoolingModelRunner, 'set_active_loras')
@patch('vllm.distributed.parallel_state._PP',
new_callable=lambda: MagicMock(spec=GroupCoordinator,
is_last_rank=False))
def test_execute_model_lora_config(self, mock_pp, set_active_loras,
mock_set_forward):
"""Test LoRA configuration handling"""
# Setup
self.runner.lora_config = True
mock_input = MagicMock()
mock_input.lora_requests = ["req1"]
mock_input.lora_mapping = {"map": 1}
# Execute
self.runner.execute_model(model_input=mock_input,
kv_caches=[],
num_steps=1)
# Verify LoRA call
set_active_loras.assert_called_once_with(["req1"], {"map": 1})
@patch('vllm.forward_context.set_forward_context')
@patch('vllm.distributed.parallel_state._PP',
new_callable=lambda: MagicMock(spec=GroupCoordinator,
is_last_rank=False))
def test_execute_model_not_last_rank(self, mock_pp, mock_set_forward):
"""Test behavior when not the last pipeline rank"""
# Setup
# Execute
self.runner.execute_model(model_input=MagicMock(
input_tokens=torch.tensor([1]),
input_positions=torch.tensor([0]),
multi_modal_kwargs={}),
kv_caches=[],
num_steps=1)
# Verify pooler not called
self.runner.model.pooler.assert_not_called()