[1/N][refactor] torchair deepseek modeling refactor (#2384)
### What this PR does / why we need it?
Move torchair related model arch into torchair moduel to make the code
clear. Next step we'll remove all torchair related code outside of
torchair moduel.
### Does this PR introduce _any_ user-facing change?
No.
- vLLM version: v0.10.0
- vLLM main:
08d5f7113a
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair import utils
|
||||
@@ -26,3 +28,46 @@ class TestTorchairUtils(TestBase):
|
||||
"Delete torchair cache dir failed")
|
||||
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
|
||||
"Delete kv cache bytes cache dir failed")
|
||||
|
||||
def test_torchair_cache_dir_multiple_ranks(self):
|
||||
ranks = [0, 1, 2, 3]
|
||||
values = [100, 200, 300, 400]
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
|
||||
for rank, expected in zip(ranks, values):
|
||||
self.assertEqual(expected,
|
||||
utils.read_kv_cache_bytes_from_file(rank))
|
||||
utils.delete_torchair_cache_file()
|
||||
|
||||
self.assertFalse(utils.check_torchair_cache_exist(),
|
||||
"Delete torchair cache dir failed")
|
||||
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
|
||||
"Delete kv cache bytes cache dir failed")
|
||||
|
||||
@patch('vllm.ModelRegistry')
|
||||
def test_register_torchair_model(self, mock_model_registry):
|
||||
mock_registry = MagicMock()
|
||||
mock_model_registry.return_value = mock_registry
|
||||
utils.register_torchair_model()
|
||||
|
||||
self.assertEqual(mock_model_registry.register_model.call_count, 3)
|
||||
call_args_list = mock_model_registry.register_model.call_args_list
|
||||
|
||||
expected_registrations = [
|
||||
("DeepSeekMTPModel",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
|
||||
),
|
||||
("DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
|
||||
),
|
||||
("DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||
)
|
||||
]
|
||||
|
||||
for i, (expected_name,
|
||||
expected_path) in enumerate(expected_registrations):
|
||||
args, kwargs = call_args_list[i]
|
||||
self.assertEqual(args[0], expected_name)
|
||||
self.assertEqual(args[1], expected_path)
|
||||
|
||||
Reference in New Issue
Block a user