### What this PR does / why we need it?
Move torchair-related model architectures into the torchair module to make the
code clearer. As a next step, we'll remove all torchair-related code outside of
the torchair module.
### Does this PR introduce _any_ user-facing change?
No.
- vLLM version: v0.10.0
- vLLM main:
08d5f7113a
Signed-off-by: linfeng-yuan <1102311262@qq.com>
74 lines
3.0 KiB
Python
74 lines
3.0 KiB
Python
import os
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.torchair import utils
|
|
|
|
|
|
class TestTorchairUtils(TestBase):
    """Unit tests for the helpers exposed by ``vllm_ascend.torchair.utils``.

    Covers work-directory resolution, the KV-cache-bytes cache file round
    trip (single rank and several ranks written concurrently), and the
    registration of the torchair model architectures.
    """

    def test_get_torchair_current_work_dir(self):
        """Work dir equals the cache dir, optionally joined with a sub-path."""
        cache_dir = utils.TORCHAIR_CACHE_DIR

        # Without an argument the helper returns the cache dir itself.
        work_dir = utils._get_torchair_current_work_dir()
        self.assertEqual(cache_dir, work_dir)

        # With an argument the helper appends it to the cache dir.
        work_dir = utils._get_torchair_current_work_dir("test")
        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)

    def test_torchair_cache_dir(self):
        """Write, read back and delete the KV-cache-bytes entry for rank 0."""
        utils.write_kv_cache_bytes_to_file(0, 100)
        self.assertTrue(utils.check_torchair_cache_exist(),
                        "Create torchair cache dir failed")
        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
                        "Create kv cache bytes cache dir failed")

        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
        self.assertEqual(100, kv_cache_bytes)

        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_torchair_cache_dir_multiple_ranks(self):
        """Concurrent writes from several ranks are each readable afterwards."""
        ranks = [0, 1, 2, 3]
        values = [100, 200, 300, 400]

        with ThreadPoolExecutor() as executor:
            # Executor.map only re-raises a worker's exception when its
            # result is retrieved from the returned iterator. Drain it so a
            # failed write fails the test here instead of being dropped.
            list(
                executor.map(utils.write_kv_cache_bytes_to_file, ranks,
                             values))

        # The executor has shut down, so every write has completed.
        for rank, expected in zip(ranks, values):
            self.assertEqual(expected,
                             utils.read_kv_cache_bytes_from_file(rank))

        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    @patch('vllm.ModelRegistry')
    def test_register_torchair_model(self, mock_model_registry):
        """Each torchair architecture is registered, in the expected order."""
        utils.register_torchair_model()

        # (architecture name, "module:class" path) pairs, in call order.
        expected_registrations = [
            ("DeepSeekMTPModel",
             "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
             ),
            ("DeepseekV2ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
             ),
            ("DeepseekV3ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
             ),
        ]

        # Check the count first so the zip below cannot silently truncate.
        self.assertEqual(mock_model_registry.register_model.call_count, 3)
        register_calls = mock_model_registry.register_model.call_args_list

        for (args, _kwargs), (expected_name, expected_path) in zip(
                register_calls, expected_registrations):
            self.assertEqual(args[0], expected_name)
            self.assertEqual(args[1], expected_path)