import os
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock, patch

from tests.ut.base import TestBase
from vllm_ascend.torchair import utils


class TestTorchairUtils(TestBase):
    """Unit tests for the helpers in ``vllm_ascend.torchair.utils``."""

    def test_get_torchair_current_work_dir(self):
        """Check the work dir with and without a sub-directory name.

        Without an argument the helper must return ``TORCHAIR_CACHE_DIR``;
        with a name it must return that name joined beneath the cache dir.
        """
        cache_dir = utils.TORCHAIR_CACHE_DIR

        work_dir = utils._get_torchair_current_work_dir()
        self.assertEqual(cache_dir, work_dir)

        work_dir = utils._get_torchair_current_work_dir("test")
        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)

    def test_torchair_cache_dir(self):
        """Round-trip a kv cache byte count for rank 0, then delete it."""
        utils.write_kv_cache_bytes_to_file(0, 100)
        self.assertTrue(utils.check_torchair_cache_exist(),
                        "Create torchair cache dir failed")
        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
                        "Create kv cache bytes cache dir failed")

        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
        self.assertEqual(100, kv_cache_bytes)

        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_torchair_cache_dir_multiple_ranks(self):
        """Write one value per rank from a thread pool and read each back."""
        ranks = [0, 1, 2, 3]
        values = [100, 200, 300, 400]

        with ThreadPoolExecutor() as executor:
            # ``Executor.map`` only re-raises a worker's exception when its
            # result is retrieved from the returned iterator. Drain it so a
            # failed write fails the test here instead of being dropped and
            # surfacing later as a confusing read/assert error.
            list(
                executor.map(utils.write_kv_cache_bytes_to_file, ranks,
                             values))

        for rank, expected in zip(ranks, values):
            self.assertEqual(expected,
                             utils.read_kv_cache_bytes_from_file(rank))

        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    @patch('vllm.ModelRegistry')
    def test_register_torchair_model(self, mock_model_registry):
        """Check the (name, path) pairs passed to ``register_model``.

        ``vllm.ModelRegistry`` is patched, so the calls are recorded on the
        mock and their first two positional arguments are compared, in
        order, against the expected registrations.
        """
        mock_model_registry.return_value = MagicMock()

        utils.register_torchair_model()

        expected_registrations = [
            ("DeepSeekMTPModel",
             "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
             ),
            ("DeepseekV2ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
             ),
            ("DeepseekV3ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
             ),
        ]

        # Assert the count first: ``zip`` below would otherwise silently
        # ignore missing or extra registrations.
        self.assertEqual(mock_model_registry.register_model.call_count,
                         len(expected_registrations))

        recorded_calls = mock_model_registry.register_model.call_args_list
        for (expected_name, expected_path), (args, _kwargs) in zip(
                expected_registrations, recorded_calls):
            self.assertEqual(args[0], expected_name)
            self.assertEqual(args[1], expected_path)