[Test] Add ut test for torchair (#4287)
### What this PR does / why we need it?
The current community lacks unit tests (UT) for files such as
torchair_worker, mtp_proposer, and model_runner. Therefore, UT coverage
for these files needs to be added.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: CodeNine-CJ <chenjian343@huawei.com>
This commit is contained in:
85
tests/ut/torchair/test_torchair_mtp_proposer.py
Normal file
85
tests/ut/torchair/test_torchair_mtp_proposer.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class TestTorchairMtpProposer(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
|
||||
vllm_config = MagicMock(spec=VllmConfig)
|
||||
vllm_config.device_config = MagicMock()
|
||||
vllm_config.device_config.device = torch.device("cpu")
|
||||
vllm_config.speculative_config = MagicMock()
|
||||
vllm_config.speculative_config.draft_model_config = MagicMock()
|
||||
vllm_config.speculative_config.draft_model_config.dtype = torch.float16
|
||||
vllm_config.speculative_config.method = "deepseek_mtp"
|
||||
vllm_config.speculative_config.num_speculative_tokens = 5
|
||||
vllm_config.load_config = MagicMock()
|
||||
cache_config = CacheConfig(block_size=16)
|
||||
vllm_config.cache_config = cache_config
|
||||
vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
|
||||
max_num_seqs=64)
|
||||
|
||||
device = torch.device("cpu")
|
||||
runner = MagicMock()
|
||||
runner.pcp_size = 1
|
||||
runner.dcp_size = 1
|
||||
runner.pcp_rank = 0
|
||||
runner.max_num_tokens = 1024
|
||||
runner.max_num_reqs = 10
|
||||
runner._use_aclgraph.return_value = True
|
||||
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
|
||||
return_value=None)
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
mock_set_default_dtype = mocker.patch(
|
||||
'vllm.model_executor.model_loader.utils.set_default_torch_dtype'
|
||||
)
|
||||
else:
|
||||
mock_set_default_dtype = mocker.patch(
|
||||
'vllm.utils.torch_utils.set_default_torch_dtype')
|
||||
mock_set_default_dtype.return_value.__enter__.return_value = None
|
||||
|
||||
mock_model_loader = MagicMock()
|
||||
mocker.patch("vllm.model_executor.model_loader.get_model_loader",
|
||||
return_value=mock_model_loader)
|
||||
mock_layers = {
|
||||
"target_attn_layer_1": Mock(),
|
||||
"draft_attn_layer_2": Mock()
|
||||
}
|
||||
mocker.patch("vllm.config.get_layers_from_vllm_config",
|
||||
return_value=mock_layers)
|
||||
mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
|
||||
mock_set_current.return_value.__enter__.return_value = None
|
||||
mock_torchair_deepseek_mtp = MagicMock()
|
||||
mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
|
||||
return_value=mock_torchair_deepseek_mtp)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.model_loader.utils.process_weights_after_loading"
|
||||
)
|
||||
|
||||
proposer = TorchairMtpProposer(vllm_config, device, runner)
|
||||
proposer.vllm_config = vllm_config
|
||||
proposer.device = device
|
||||
proposer.runner = runner
|
||||
proposer.speculative_config = vllm_config.speculative_config
|
||||
proposer.draft_model_config = vllm_config.speculative_config.draft_model_config
|
||||
proposer.method = vllm_config.speculative_config.method
|
||||
|
||||
return proposer, mock_model_loader, mock_torchair_deepseek_mtp
|
||||
|
||||
def test_init(self, setup_torchair_mtp_proposer):
|
||||
proposer, _, _, = setup_torchair_mtp_proposer
|
||||
assert isinstance(proposer, TorchairMtpProposer)
|
||||
Reference in New Issue
Block a user