Drop torchair (#4814)

aclgraph is now stable and fast, so let's drop the torchair graph mode.

TODO: some logic that was added to adapt to torchair should be cleaned up as
well. We'll do that in a follow-up PR.
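
For context, torchair graph mode was driven entirely by `VllmConfig.additional_config`. Below is a sketch of the option block that stops working after this commit, reconstructed from the deleted `TestAscendConfig` cases further down (the key names come from those tests; the comments are our reading of the test assertions, not authoritative docs):

    # Former torchair knobs, passed as additional_config when starting vLLM.
    # After this commit the whole "torchair_graph_config" section is gone.
    additional_config = {
        "torchair_graph_config": {
            "enabled": True,                  # master switch; every other knob required it
            "use_cached_graph": True,         # RuntimeError if set while enabled=False
            "graph_batch_sizes": [1, 2, 4],   # must be a list, else TypeError
            "graph_batch_sizes_init": False,  # ValueError if True while sizes are given
            "enable_multistream_mla": True,
            "enable_view_optimize": True,
            "enable_frozen_parameter": True,
            "enable_kv_nz": True,
        },
        "refresh": True,  # re-initialize the cached AscendConfig
    }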

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Author: wangxiyuan
Date: 2025-12-10 09:20:40 +08:00 (committed by GitHub)
Parent: ba9cda9dfd
Commit: 835b4c8f1d

84 changed files with 77 additions and 16881 deletions


@@ -118,7 +118,6 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None
)), \


@@ -110,11 +110,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -139,9 +134,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
@@ -165,9 +157,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
@@ -190,9 +179,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
@@ -219,9 +205,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
@@ -415,7 +398,6 @@ class TestAscendMRotaryEmbedding(unittest.TestCase):
mrope_section=self.mrope_section)
self.mock_config = MagicMock()
self.mock_config.torchair_graph_config.enabled = False
def _create_vllm_config(self):
vllm_config = VllmConfig()


@@ -33,7 +33,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
mock_get_ep_group.return_value = mock_ep_group
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_ascend_config.enable_chunked_prefill = False
mock_get_ascend_config.return_value = mock_ascend_config
mock_mc2_group = Mock(device_group=0)


@@ -13,15 +13,10 @@
 # This file is a part of the vllm-ascend project.
 #
-import os
-from transformers import PretrainedConfig
-from vllm.config import ModelConfig, ParallelConfig, VllmConfig
+from vllm.config import VllmConfig
 from tests.ut.base import TestBase
-from vllm_ascend.ascend_config import (_check_torchair_supported,
-                                       check_ascend_config,
-                                       clear_ascend_config, get_ascend_config,
+from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config,
                                        init_ascend_config)
@@ -45,17 +40,6 @@ class TestAscendConfig(TestBase):
self.assertIsNone(ascend_config.expert_map_path)
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config
self.assertFalse(torchair_graph_config.enabled)
self.assertEqual(torchair_graph_config.mode, '')
self.assertFalse(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertFalse(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertFalse(torchair_graph_config.enable_kv_nz)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
@@ -63,16 +47,6 @@ class TestAscendConfig(TestBase):
def test_init_ascend_config_with_additional_config(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4],
"graph_batch_sizes_init": False,
"enable_multistream_mla": True,
"enable_view_optimize": True,
"enable_frozen_parameter": True,
"enable_kv_nz": True
},
"ascend_compilation_config": {
"enable_quantization_fusion": False,
},
@@ -84,65 +58,9 @@ class TestAscendConfig(TestBase):
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config
self.assertTrue(torchair_graph_config.enabled)
self.assertTrue(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertTrue(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertTrue(torchair_graph_config.enable_kv_nz)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
@_clean_up_ascend_config
def test_init_ascend_config_with_refresh(self):
test_vllm_config = VllmConfig()
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertTrue(ascend_config.torchair_graph_config.enabled)
@_clean_up_ascend_config
def test_init_ascend_config_with_wrong_input(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"graph_batch_sizes": "fake_size",
},
"refresh": True,
}
with self.assertRaises(TypeError):
init_ascend_config(test_vllm_config)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": True,
},
"refresh": True,
}
with self.assertRaises(ValueError):
init_ascend_config(test_vllm_config)
@_clean_up_ascend_config
def test_get_ascend_config(self):
test_vllm_config = VllmConfig()
@@ -162,203 +80,3 @@ class TestAscendConfig(TestBase):
clear_ascend_config()
with self.assertRaises(RuntimeError):
get_ascend_config()
@_clean_up_ascend_config
def test_check_ascend_config_pass(self):
test_vllm_config = VllmConfig()
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
@_clean_up_ascend_config
def test_check_ascend_config_wrong_case(self):
test_vllm_config = VllmConfig()
# torchair + eager mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
enforce_eager = True
check_ascend_config(test_vllm_config, enforce_eager)
# torchair + non deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "llama"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
('qwen', True), ('llama', False)]
for model_type, expected_output in test_cases:
self.assertEqual(_check_torchair_supported(model_type),
expected_output)
@_clean_up_ascend_config
def test_ascend_config_load_error(self):
test_vllm_config = VllmConfig()
# graph_batch_sizes should be list.
with self.assertRaises(TypeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"graph_batch_sizes": "fake_size",
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_graph should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_graph": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes should not be set without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4],
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes_init should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes_init": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_multistream_mla should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_mla": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# mode should not be configured without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"mode": 'max-autotune',
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_kv_nz should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_kv_nz": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"lmhead_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"oproj_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=1)
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
test_vllm_config.model_config = ModelConfig(model=model_path,
enforce_eager=True)
init_ascend_config(test_vllm_config)

View File

@@ -31,7 +31,6 @@ class TestNPUPlatform(TestBase):
@staticmethod
def mock_vllm_ascend_config():
mock_ascend_config = MagicMock()
mock_ascend_config.torchair_graph_config.enabled = False
mock_ascend_config.xlite_graph_config.enabled = False
mock_ascend_config.enable_shared_expert_dp = False
return mock_ascend_config
@@ -403,47 +402,6 @@ class TestNPUPlatform(TestBase):
CUDAGraphMode.NONE,
)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.utils.update_default_aclgraph_sizes")
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch(
"vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config"
)
def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_recompute, mock_init_ascend, mock_check_ascend,
mock_update_default, mock_soc_version):
mock_update_default.return_value = MagicMock()
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Torchair compilation enabled" in cm.output[0])
self.assertEqual(
vllm_config.compilation_config.mode,
CompilationMode.NONE,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@@ -503,16 +461,6 @@ class TestNPUPlatform(TestBase):
"vllm_ascend.worker.worker_v1.NPUWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.xlite_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
@@ -550,14 +498,7 @@ class TestNPUPlatform(TestBase):
self.platform.check_and_update_config(vllm_config)
self.assertEqual(vllm_config.compilation_config.custom_ops, [])
-    @patch('vllm_ascend.platform.get_ascend_config')
-    def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
-        mock_config = MagicMock()
-        mock_config.torchair_graph_config.enabled = False
-        mock_config.enable_shared_expert_dp = False
-        mock_get_ascend_config.return_value = mock_config
+    def test_get_attn_backend_cls_use_v1_and_mla(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
@@ -570,56 +511,7 @@ class TestNPUPlatform(TestBase):
self.assertEqual(result,
"vllm_ascend.attention.mla_v1.AscendMLABackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_mla_and_torchair(
self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=True,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_torchair(self,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
#use_sfa=False,
use_mla=False,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)
-    @patch('vllm_ascend.platform.get_ascend_config')
-    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
-        mock_config = MagicMock()
-        mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
+    def test_get_attn_backend_cls_use_v1_only(self):
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,


@@ -1,61 +0,0 @@
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.qwen3_moe import CustomSparseMoeBlock
class TestCustomSparseMoeBlock(PytestBase):
@pytest.fixture
def setup_csmb(self, mocker: MockerFixture):
config = PretrainedConfig(num_experts=64,
hidden_size=2048,
num_experts_per_tok=2,
moe_intermediate_size=1408,
norm_topk_prob=True)
mocker.patch(
'vllm_ascend.torchair.models.qwen3_moe.get_tensor_model_parallel_world_size',
return_value=10)
mocker.patch(
'vllm.model_executor.layers.linear.ReplicatedLinear.__init__',
return_value=None)
mocker.patch(
'vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__init__',
return_value=None)
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_tp_group',
return_value=tp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_dp_group',
return_value=dp_group)
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_ep_group',
return_value=ep_group)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
custom_moe_block = CustomSparseMoeBlock(config, None, "")
return custom_moe_block
def test_init(self, mocker: MockerFixture, setup_csmb):
custom_moe_block = setup_csmb
assert isinstance(custom_moe_block, CustomSparseMoeBlock)


@@ -1,206 +0,0 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
TorchairDeepSeekMultiTokenPredictorLayer)
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
@pytest.fixture
def setup_mtp_layer(self, mocker: MockerFixture):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
return_value=1)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
return_value=None)
mocker.patch(
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
return_value=None)
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
return mtp_layer
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch.object(mtp_layer,
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
kv_cache = torch.randn(2, 3, 768)
previous_hidden_states = torch.randn(2, 3, 768)
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (3, 768)
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
@pytest.fixture
def setup_predictor(self, mocker: MockerFixture):
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_model_config = mocker.MagicMock(spec=ModelConfig)
mock_hf_config = mocker.MagicMock()
mock_hf_config.num_hidden_layers = 12
mock_hf_config.num_nextn_predict_layers = 3
mock_hf_config.vocab_size = 30000
mock_model_config.hf_config = mock_hf_config
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = CacheConfig()
mock_vllm_config.quant_config = mocker.MagicMock()
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
predictor = TorchairDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
return predictor
def test_init(self, mocker: MockerFixture, setup_predictor):
predictor = setup_predictor
assert predictor.num_mtp_layers == 3
assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)
@pytest.mark.parametrize(
'kv_caches, inputs_embeds',
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
inputs_embeds):
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
# todo: need or not?
# predictor.num_mtp_layers = 1
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
output = predictor.forward(input_ids, positions, kv_caches, None, None,
inputs_embeds, 0)
mock_layer.assert_called_once()
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
return_value=None)
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
result_logits = predictor.compute_logits(hidden_states=hidden_states)
predictor.logits_processor.assert_called_once()
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestTorchairDeepSeekMTP(PytestBase):
@pytest.fixture
def setup_mtp(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.model_config.hf_config.num_hidden_layers = 12
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
vllm_config.cache_config = mocker.MagicMock()
vllm_config.quant_config = mocker.MagicMock()
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
ascend_config = mocker.MagicMock()
ascend_config.max_num_batched_tokens = 2048
ascend_config.max_model_len = 1024
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=ascend_config)
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
return mtp
def test_init(self, mocker: MockerFixture, setup_mtp):
mtp = setup_mtp
assert isinstance(mtp, TorchairDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
spec_step_idx = 0
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))


@@ -1,366 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.transformers_utils.config import patch_rope_parameters
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
TorchairDeepseekV2RowParallelLinear,
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2SiluAndMul)
@pytest.fixture
def base_config():
config = PretrainedConfig(
hidden_size=128,
num_attention_heads=8,
num_hidden_layers=2,
intermediate_size=256,
hidden_act="silu",
rms_norm_eps=1e-6,
rope_theta=10000.0,
max_position_embeddings=2048,
n_routed_experts=4,
n_shared_experts=1,
moe_intermediate_size=256,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
first_k_dense_replace=0,
moe_layer_freq=1,
kv_lora_rank=16,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
topk_method="noaux_tc",
scoring_func="softmax",
norm_topk_prob=True,
n_group=1,
topk_group=1,
vocab_size=10000,
)
patch_rope_parameters(config)
return config
@pytest.fixture
def vllm_config(base_config):
model_config = SimpleNamespace(
hf_config=base_config,
tensor_parallel_size=1,
dtype=torch.float32,
use_mla=False,
quant_config=None,
max_model_len=2048,
)
cache_config = CacheConfig()
vllm_config = Mock()
vllm_config.model_config = model_config
vllm_config.cache_config = cache_config
vllm_config.quant_config = None
return vllm_config
@pytest.fixture
def mock_distributed():
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
pp_group = Mock(spec=GroupCoordinator)
pp_group.rank_in_group = 0
pp_group.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mlp_tp_group = Mock(spec=GroupCoordinator)
mlp_tp_group.rank_in_group = 0
mlp_tp_group.world_size = 1
mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \
patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \
patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context",
return_value=forward_context):
yield
@pytest.fixture
def patch_attention_init():
try:
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
DeepseekV2Attention
original_init = DeepseekV2Attention.__init__
def patched_init(self, *args, **kwargs):
kwargs.pop("decoder_layer", None)
if 'vllm_config' not in kwargs:
mock_vllm_config = Mock()
mock_vllm_config.model_config = Mock()
mock_vllm_config.model_config.hf_config = Mock()
mock_vllm_config.model_config.hf_config.hidden_size = 128
mock_vllm_config.model_config.dtype = torch.float32
mock_vllm_config.model_config.quant_config = None
mock_vllm_config.cache_config = CacheConfig()
kwargs['vllm_config'] = mock_vllm_config
return original_init(self, *args, **kwargs)
DeepseekV2Attention.__init__ = patched_init
yield
DeepseekV2Attention.__init__ = original_init
except ImportError:
yield
def test_torchair_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu")
silu = TorchairDeepseekV2SiluAndMul()
assert silu.weight_scale is None
x = torch.randn(2, 4)
output = silu.forward_oot(x)
assert output.shape == (2, 2)
weight_scale = Mock(return_value=torch.tensor(0.1))
silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale)
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
dynamic_scale = torch.randn(2, 1)
with patch("torch_npu.npu_dequant_swiglu_quant",
return_value=torch.randn(2, 4)):
output = silu.forward_oot((quant_x, dynamic_scale))
assert output.shape == (2, 4)
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
output_sizes=[64, 64],
bias=False,
quant_config=None)
assert linear.output_sizes == [64, 64]
param = Mock()
param.data = torch.zeros(128, 128)
param.output_dim = 1
param.is_gguf_weight = False
param.is_gguf_weight_type = False
loaded_weight = torch.randn(128, 64)
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
with pytest.raises(AssertionError):
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
@pytest.mark.parametrize("cls", [
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed, mock_forward_context):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
input_ = torch.randn(2, 4, 128)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
linear.input_is_parallel = False
output = linear(input_, is_prefill=True)
assert output[0].shape == (2, 4, 64)
linear.input_is_parallel = True
output = linear(input_, is_prefill=False)
assert output[0].shape == (2, 4, 64)
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
mlp = TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=None)
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
) as mock_quant_config:
mock_quant_config.name = "w8a8dynamic"
with pytest.raises(NotImplementedError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=mock_quant_config,
force_replicate=False)
with pytest.raises(ValueError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="relu",
quant_config=None)
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = TorchairDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=16,
kv_lora_rank=16,
cache_config=CacheConfig(),
quant_config=None,
prefix="layers.0.self_attn")
assert attn.debug_layer_idx == 0
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=None,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config, mock_forward_context,
patch_attention_init):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128))
base_config.n_routed_experts = 4
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
cache_config=CacheConfig(),
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MoE)
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
hidden_states, residual = layer(positions, x, None)
assert hidden_states.shape == (2, 4, 128)
base_config.n_routed_experts = None
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
patch_attention_init):
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4))
positions = torch.arange(4).repeat(2, 1)
with patch.object(model.model,
"forward",
return_value=torch.randn(2, 4, 128)):
output = model(input_ids, positions)
assert output.shape == (2, 4, 128)
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
with patch(
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
assert loaded is not None


@@ -1,423 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend
from vllm_ascend.ascend_forward_context import get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import AscendDeviceType
adapt_patch(True)
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
with patch('torch.npu.is_available', return_value=True), \
patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
enable_shared_expert_dp=False,
expert_map_path=None,
init_redundancy_expert=2,
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=dp_metadata,
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)):
yield
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
# init moe env patch
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
"""default moe config"""
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
moe.moe_parallel_config.use_ep = False
moe.moe_parallel_config.dp_size = 1
return TorchairAscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
class TestTorchairAscendFusedMoe:
@pytest.fixture
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = TorchairAscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
# check group_topk
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = TorchairAscendFusedMoE(**error_config)
# check scoring_func
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = TorchairAscendFusedMoE(**error_config)
@pytest.fixture
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = False
with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
@pytest.fixture
def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = True
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method,
TorchairAscendUnquantizedFusedMoEMethod)
@pytest.fixture
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
"""
1 test has shared_experts
2 test has top_k
3 test is_prefill is true
4 test single num_tokens(decode)
5 test ep_size is 1 and is_prefill is true
6 test ep_size is 1 and is_prefill is False
"""
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
@pytest.fixture
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestTorchairAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size = others_param
is_prefill = False
global_redundant_expert_num = vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config(
).init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test use_select_experts and use fused_experts_with_mc2
2 test use_select_experts and fused_experts_with_all2all_buffer
3 test use_select_experts and fused_experts_with_all2all
4 test use_select_experts and fused_experts
"""
ep_size = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_device_type", return_value=AscendDeviceType._910_93):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape


@@ -1,333 +0,0 @@
import math
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
_set_cos_sin_cache, custom_rotary_embedding_enabled,
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
yarn_find_correction_dim, yarn_get_mscale)
from vllm_ascend.utils import AscendDeviceType
class TestCustomRotaryEmbeddingEnabled(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=False):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestRopeForwardOot(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
mock_get_ascend_config):
# Setup mock for torchair enabled
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.mock_self.forward_native.assert_called_once_with(
self.positions, self.query, self.key, None)
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@patch('torch.ops._C_ascend')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._910_93)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled,
mock_soc_version,
mock_get_ascend_config, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
non_contig_query, non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
rope_forward_oot(self.mock_self, self.positions, self.query,
self.key, offsets)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = rope_forward_oot(self.mock_self,
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
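# Minimal stand-in for a rotary embedding module: it carries only the
# attributes that _set_cos_sin_cache and native_rope_deepseek_forward read.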
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
self.beta_fast = 32
self.beta_slow = 1
self.max_position_embeddings = 4096
self.mscale = 1.0
self.scaling_factor = 40
def register_buffer(self, name, tensor, persistent=True):
pass
class TestSetSinCosCache(TestBase):
def test_set_cos_sin_cache(self):
module = MockRopeModule()
with patch.object(module, "register_buffer") as mock_register_buffer:
_set_cos_sin_cache(module,
1024,
device="cpu",
dtype=torch.bfloat16)
mock_register_buffer.assert_called()
class TestNativeRopeDeepseekForward(TestBase):
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == (1, 128)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_rope_forward_oot):
module = MockRopeModule(is_neox_style=False)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
class TestRotateHalf(TestBase):
def test_rotate_half_even_dim(self):
# Test with even dimension
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
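# rotate_half splits the last dim in half and returns cat(-x2, x1).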
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
result = rotate_half(x)
self.assertTrue(torch.allclose(result, expected))
class TestYarnFindCorrectionDim(TestBase):
def test_basic_case(self):
# Test with standard values
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
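# closed form: dim * ln(max_position_embeddings / (2 * pi * num_rotations)) / (2 * ln(base))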
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
class TestYarnGetMscale(TestBase):
def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
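# For scale > 1 the expected value follows 1 + 0.1 * mscale * ln(scale).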
def test_scale_greater_than_1(self):
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")

View File

@@ -1,296 +0,0 @@
from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
TorchairAscendW4A8DynamicFusedMoEMethod,
TorchairAscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(
quant_description={"group_size": 256})
mock_get_current_vllm_config.return_value = mock_vllm_config
self.method = TorchairAscendW4A8DynamicLinearMethod()
self.method.group_size = 8
def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
# new quant version weight
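# The new format packs two int4 values per int8, halving the packed dim.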
self.method.new_quant_version = True
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (16, 8))
self.assertEqual(weight["_packed_dim"], 0)
self.assertEqual(weight["_packed_factor"], 2)
def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale"].shape, (32, 1))
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset"].shape, (32, 1))
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
# new quant version weight
self.method.new_quant_version = True
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="column")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 1))
params = self.method.get_pergroup_param(8,
32,
torch.bfloat16,
layer_type="row")
self.assertEqual(params["scale_bias"].dtype, torch.float32)
self.assertEqual(params["scale_bias"].shape, (32, 16))
@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
mock_npu_convert_weight.return_value = torch.zeros((32, 4),
dtype=torch.int32)
# old quant version weight
layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(torch.zeros((32, 8),
dtype=torch.int8),
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset = torch.nn.Parameter(torch.empty_like(
layer.weight_scale.data),
requires_grad=False)
layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.method.new_quant_version = True
new_layer = torch.nn.Module()
new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8),
dtype=torch.int8),
requires_grad=False)
new_layer.weight_scale = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset = torch.nn.Parameter(torch.empty_like(
new_layer.weight_scale.data),
requires_grad=False)
new_layer.weight_scale_second = torch.nn.Parameter(torch.ones(
(32, 1), dtype=torch.float32),
requires_grad=False)
new_layer.weight_offset_second = torch.nn.Parameter(
torch.empty_like(new_layer.weight_scale_second.data),
requires_grad=False)
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))
self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32))
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
experts = 8
input_size = 16
output_size = 56
group_size = 2
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
)
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
mock_get_ep_group, get_current_vllm_config):
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_get_ascend_config.return_value = mock_ascend_config
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
def test_get_weight(self):
# old quant version w4a8 weight
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, 2 * self.input_size, self.output_size))
# new quant version weight
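# With int4 packing the second dim is halved (2 * input_size -> input_size).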
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, self.input_size, self.output_size))
def test_get_dynamic_quant_param(self):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.float32)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size,
self.input_size // self.group_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
# per-channel weight
self.quant_method.is_per_channel_weight = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
pergroup_param = [
"w13_weight_scale_second", "w13_weight_offset_second",
"w2_weight_scale_second", "w2_weight_offset_second"
]
is_contains = any(key in param_dict for key in pergroup_param)
self.assertFalse(is_contains)
def build_layer(self,
is_new_quant_version=True,
is_per_channel_weight=False):
layer = torch.nn.Module()
if is_new_quant_version:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8),
requires_grad=False)
w13_scale_bias = torch.zeros(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros((self.experts, self.output_size,
16 // self.quant_method.tp_size),
dtype=torch.float32)
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.float32),
requires_grad=False)
if not is_per_channel_weight:
layer.w13_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w13_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w13_weight_scale_second.data),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w2_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w2_weight_scale_second.data),
requires_grad=False)
return layer
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
# old quant version weight
layer = self.build_layer(is_new_quant_version=False)
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
self.assertTrue(hasattr(layer, "w2_scale_bias"))
self.assertEqual(layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer = self.build_layer(is_new_quant_version=True)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
# per-channel weight
self.quant_method.is_per_channel_weight = True
per_channel_layer = self.build_layer(is_new_quant_version=True,
is_per_channel_weight=True)
self.quant_method.process_weights_after_loading(per_channel_layer)
self.assertEqual(per_channel_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))

View File

@@ -1,129 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
from vllm_ascend.utils import AscendDeviceType
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing_quant")
def test_torchair_fused_experts_with_all2all(
self, mock_npu_moe_init_routing_quant, mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
mock_moe_re_routing, mock_all_to_all_single):
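# Every NPU kernel on the all2all path is stubbed so the expert-routing
# logic can be exercised on CPU.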
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_npu_moe_init_routing_quant.return_value = (
placeholder_int8, placeholder_ones, placeholder_ones,
torch.bincount(placeholder_ones, minlength=len(expert_map)),
torch.randn(self.num_tokens))
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
result = torchair_fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))
@patch.dict('os.environ', {
'HCCL_INTRA_ROCE_ENABLE': '0',
'HCCL_INTRA_PCIE_ENABLE': '1'
})
@patch(
"vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_device_type"
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
)
@patch('torch_npu.npu_moe_distribute_combine_v2')
@patch('torch_npu.npu_moe_distribute_dispatch_v2')
@patch(
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
)
def test_torchair_fused_experts_with_mc2_a2_optimization(
self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
mock_ascend_soc_version):
"""Test expert_scales is passed in A2 SOC version with mc2 optimization"""
# Setup mocks
mock_ascend_soc_version.return_value = AscendDeviceType._910B
mock_group = MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 4
mock_get_group.return_value = mock_group
mock_combine.return_value = self.placeholder
mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
torch.randint(0, 32, (32, )),
torch.randint(1, 5, (8, )),
torch.randint(1, 5, (4, )), None,
torch.randn(32))
mock_mlp_decode.return_value = self.placeholder
result = torchair_fused_experts_with_mc2(
hidden_states=self.placeholder,
w1=self.placeholder,
w2=self.placeholder,
w1_scale=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=2,
mc2_mask=self.placeholder)
# Check that expert_scales was passed to dispatch
call_args = mock_dispatch.call_args[1]
self.assertIn('expert_scales', call_args)
self.assertIsInstance(result, torch.Tensor)
self.assertEqual(result.shape, self.placeholder.shape)

View File

@@ -1,95 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_attention import \
AscendAttentionTorchairBackendImpl
class TestAscendAttentionTorchairBackendImpl(TestBase):
@patch("torch.zeros")
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
mock_zeros):
mock_tp.world_size = 2
ascend_config.torchair_graph_config.enabled = True
ascend_config.torchair_graph_config.enable_kv_nz = False
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
num_heads = 32
head_size = 128
scale = 0.1
num_kv_heads = 4
kv_cache_dtype = "auto"
attn_type = AttentionType.DECODER
mock_zeros.return_value = torch.ones((),
device='cpu',
dtype=torch.int32)
self.impl = AscendAttentionTorchairBackendImpl(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=attn_type,
kv_sharing_target_layer_name=None)
@patch("torch_npu.npu_scatter_nd_update_")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_with_decode_only(self, mock_fused, _):
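# Decode-only forward should hit the fused infer-attention kernel and
# preserve the token dimension of the output.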
layer = MagicMock()
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
seq_len = 1
num_tokens = 100
num_blocks = 256
block_size = 4
query = torch.randn(num_tokens, seq_len,
self.impl.num_heads * self.impl.head_size)
key = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
value = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
kv_cache = (torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size),
torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.head_size)
decode = MagicMock() # TODO
decode.seq_lens_list = [2] * num_tokens
decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
decode.attn_mask = None
metadata = MagicMock()
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
metadata.decode = decode
mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
self.impl.head_size),
torch.ones(1))
result = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -1,887 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch import nn
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.torchair_mla import (
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
AscendMLATorchairImpl, AscendMLATorchairMetadata,
AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
class TestAscendMLATorchairBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendMLATorchairBackend.get_name(),
"ASCEND_MLA_TORCHAIR")
def test_get_builder_cls(self):
self.assertEqual(AscendMLATorchairBackend.get_builder_cls(),
AscendMLATorchairMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendMLATorchairBackend.get_impl_cls()
self.assertEqual(result, AscendMLATorchairImpl)
class TestAscendMLATorchairPrefillMetadata(TestBase):
def test_ascend_mla_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)
class TestAscendMLATorchairDecodeMetadata(TestBase):
def test_ascend_mla_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendMLATorchairDecodeMetadata(input_positions,
block_table, seq_lens,
max_seq_lens, seq_lens_list,
attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(metadata.attn_mask)
class TestAscendMLATorchairMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendMLATorchairMetadata(
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
block_tables, num_decodes, num_decode_tokens, num_prefills,
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendMLATorchairMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.torchair_graph_enabled, True)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
input_batch.swap_states.assert_not_called()
def test_reorder_batch_without_torchair_graph(self):
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
scheduler_output.scheduled_spec_decode_tokens = {
0: [],
1: [1],
2: [],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
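# The builder pads block tables out to max_model_len // block_size
# = 1024 // 16 = 64 columns.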
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@pytest.mark.skip(reason="Skipping this test temporarily.")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_dummy(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 1)
self.assertEqual(metadata.num_decode_tokens, 1)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 64)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_decode(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
model = MagicMock(spec=nn.Module)
model.model = MagicMock(spec=nn.Module)
builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
builder.sin_cache = torch.tensor([10, 10])
builder.cos_cache = torch.tensor([10, 10])
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([1, 1, 1]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)
metadata = builder.build(1, common_attn_metadata, model)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 3)
self.assertEqual(metadata.num_decode_tokens, 3)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state,
AscendAttentionState.ChunkedPrefill)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 10)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 4)
class TestAscendMLATorchairImpl(TestBase):
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
mock_tp.world_size = 2
ascend_config.torchair_graph_config.enabled = True
ascend_config.torchair_graph_config.enable_kv_nz = False
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"
kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"q_lora_rank": 64,
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"q_b_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
}
self.impl = AscendMLATorchairImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)
def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.q_lora_rank, 64)
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.rotary_emb)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
self.assertTrue(self.impl.torchair_graph_enabled)
def test_v_up_proj_and_o_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank)
self.impl.o_proj.return_value = (torch.randn(
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank,
self.impl.v_head_dim)
result = self.impl._v_up_proj_and_o_proj(x)
self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1],
self.impl.num_heads * self.impl.v_head_dim)
def test_q_proj_and_k_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
self.impl.q_proj.return_value = (q_proj_output, )
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
self.impl.qk_nope_head_dim,
self.impl.kv_lora_rank)
result = self.impl._q_proj_and_k_up_proj(x)
ql_nope, q_pe = result
self.assertEqual(ql_nope.shape[0], batch_size)
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
self.assertEqual(q_pe.shape[0], batch_size)
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
def test_compute_prefill_context_none(self):
batch_size = 4
kv_cache = torch.randn(10, 1, 1, 192)
query = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
metadata = MagicMock()
metadata.prefill = None
prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8)
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
metadata, prefix_out,
prefix_lse)
self.assertTrue(torch.equal(prefix_out, out))
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20
query = torch.randn(S, N, D)
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
meta.prefill = prefill_meta
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
meta, prefix_out,
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv(self, mock_kv_cache):
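# exec_kv projects hidden states with kv_a_proj_with_mqa, applies the fused
# rmsnorm+rope cache kernel, and returns k_pe, k_nope and the combined kv.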
batch_size = 2
hidden = torch.randn(batch_size, 128)
cos = torch.randn(batch_size, 32)
sin = torch.randn(batch_size, 32)
kv_cache = (torch.randn(
4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(
4, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
slots = torch.arange(batch_size, dtype=torch.long)
proj_out = torch.randn(
batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv_cache.return_value = (torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim),
torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank),
None, None)
k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
mock_kv_cache.assert_called_once()
self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
self.assertEqual(kv.shape,
(batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv):
B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
hidden_states = torch.randn(B, N, S, H)
cos = torch.randn(B, S, 32)
sin = torch.randn(B, S, 32)
kv_cache = (
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
)
slots = torch.arange(B * S, dtype=torch.long)
proj_out = torch.randn(
B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv.return_value = (None, None,
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.qk_rope_head_dim),
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.kv_lora_rank))
k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
mock_kv.assert_called_once()
self.assertEqual(
k_pe.shape,
(B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
@patch("torch_npu.npu_interleave_rope")
def test_rope_single(self, mock_rope):
B, N, D = 2, 16, 1024
x = torch.randn(B, N, D)
cos = torch.randn(B, N, 1, D)
sin = torch.randn(B, N, 1, D)
mock_rope.return_value = x.view(B, N, 1, D)
result = self.impl.rope_single(x, cos, sin)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], D)
mock_rope.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj"
)
@patch("torch_npu._npu_paged_attention_mla")
def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj):
self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100
num_blocks = 256
block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
self.impl.num_heads,
self.impl.kv_lora_rank)
metadata = MagicMock()
metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10
mock_page_attention_mla.return_value = torch.randn(
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads,
self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, None, None,
kv_c_and_k_pe_cache, metadata)
self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill"
)
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False
num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)
metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -1,45 +0,0 @@
from unittest.mock import MagicMock, Mock
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
class TestNPUTorchairModelRunner(PytestBase):
@pytest.fixture
def setup_npu_torchair_model_runner(self, mocker: MockerFixture):
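# __init__ is bypassed; only the attributes this smoke test touches are
# wired up by hand.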
mocker.patch.object(NPUTorchairModelRunner, "__init__",
lambda self, *args, **kwargs: None)
runner = NPUTorchairModelRunner(Mock(), Mock())
runner.device = torch.device("cpu")
runner.vllm_config = MagicMock(spec=VllmConfig)
runner.speculative_config = MagicMock(
method="mtp",
num_speculative_tokens=4,
disable_padded_drafter_batch=False)
runner.ascend_config = MagicMock(enable_shared_expert_dp=False,
torchair_graph_config=MagicMock(
use_cached_graph=True,
graph_batch_sizes=[1, 2, 4]))
runner.decode_token_per_req = 2
runner.is_kv_consumer = True
runner.max_num_reqs = 100
runner.model_config = MagicMock(hf_config=MagicMock(index_topk=2))
runner.attn_backend = MagicMock(get_builder_cls=lambda: Mock())
return runner
def test_init(self, mocker: MockerFixture,
setup_npu_torchair_model_runner):
runner = setup_npu_torchair_model_runner
assert isinstance(runner, NPUTorchairModelRunner)

View File

@@ -1,78 +0,0 @@
from unittest.mock import MagicMock, Mock
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.config import CacheConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
class TestTorchairMtpProposer(PytestBase):
@pytest.fixture
def setup_torchair_mtp_proposer(self, mocker: MockerFixture):
vllm_config = MagicMock(spec=VllmConfig)
vllm_config.device_config = MagicMock()
vllm_config.device_config.device = torch.device("cpu")
vllm_config.speculative_config = MagicMock()
vllm_config.speculative_config.draft_model_config = MagicMock()
vllm_config.speculative_config.draft_model_config.dtype = torch.float16
vllm_config.speculative_config.method = "mtp"
vllm_config.speculative_config.num_speculative_tokens = 5
vllm_config.load_config = MagicMock()
cache_config = CacheConfig(block_size=16)
vllm_config.cache_config = cache_config
vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024,
max_num_seqs=64)
device = torch.device("cpu")
runner = MagicMock()
runner.pcp_size = 1
runner.dcp_size = 1
runner.pcp_rank = 0
runner.max_num_tokens = 1024
runner.max_num_reqs = 10
runner._use_aclgraph.return_value = True
mocker.patch(
"vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__",
return_value=None)
mock_set_default_dtype = mocker.patch(
'vllm.utils.torch_utils.set_default_torch_dtype')
mock_set_default_dtype.return_value.__enter__.return_value = None
mock_model_loader = MagicMock()
mocker.patch("vllm.model_executor.model_loader.get_model_loader",
return_value=mock_model_loader)
mock_layers = {
"target_attn_layer_1": Mock(),
"draft_attn_layer_2": Mock()
}
mocker.patch("vllm.config.get_layers_from_vllm_config",
return_value=mock_layers)
mock_set_current = mocker.patch("vllm.config.set_current_vllm_config")
mock_set_current.return_value.__enter__.return_value = None
mock_torchair_deepseek_mtp = MagicMock()
mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP",
return_value=mock_torchair_deepseek_mtp)
mocker.patch(
"vllm.model_executor.model_loader.utils.process_weights_after_loading"
)
proposer = TorchairMtpProposer(vllm_config, device, runner)
proposer.vllm_config = vllm_config
proposer.device = device
proposer.runner = runner
proposer.speculative_config = vllm_config.speculative_config
proposer.draft_model_config = vllm_config.speculative_config.draft_model_config
proposer.method = vllm_config.speculative_config.method
return proposer, mock_model_loader, mock_torchair_deepseek_mtp
def test_init(self, setup_torchair_mtp_proposer):
        proposer, _, _ = setup_torchair_mtp_proposer
assert isinstance(proposer, TorchairMtpProposer)

View File

@@ -1,340 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_sfa import (
AscendSFATorchairBackend, AscendSFATorchairDecodeMetadata,
AscendSFATorchairImpl, AscendSFATorchairMetadata,
AscendSFATorchairMetadataBuilder, AscendSFATorchairPrefillMetadata)
class TestAscendSFATorchairBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendSFATorchairBackend.get_name(),
"ASCEND_SFA_TORCHAIR")
def test_get_builder_cls(self):
self.assertEqual(AscendSFATorchairBackend.get_builder_cls(),
AscendSFATorchairMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendSFATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendSFATorchairBackend.get_impl_cls()
self.assertEqual(result, AscendSFATorchairImpl)
class TestAscendSFATorchairPrefillMetadata(TestBase):
def test_ascend_sfa_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendSFATorchairPrefillMetadata(
attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
sin=None,
cos=None,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_sfa_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
metadata = AscendSFATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
sin=None,
cos=None,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
class TestAscendSFATorchairDecodeMetadata(TestBase):
def test_ascend_sfa_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendSFATorchairDecodeMetadata(input_positions,
block_table, seq_lens,
max_seq_lens, seq_lens_list,
None, None, attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
        self.assertIsNone(metadata.attn_mask)
class TestAscendSFATorchairMetadata(TestBase):
def test_ascend_sfa_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendSFATorchairMetadata(
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
block_tables, num_decodes, num_decode_tokens, num_prefills,
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendSFATorchairMetadataBuilder(TestBase):
def test_ascend_sfa_metadata_builder_default(self):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config",
return_value=ascend_config):
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.torchair_graph_enabled, True)
        expected_max_blocks = (
            mock_vllm_config.model_config.max_model_len +
            mock_vllm_config.cache_config.block_size -
            1) // mock_vllm_config.cache_config.block_size
        self.assertEqual(builder.max_blocks, expected_max_blocks)
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_reorder_batch_with_torchair_graph(self, mock_ascend_config):
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = True
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}
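        # With the torchair graph enabled, the builder is expected to keep the
        # batch order as-is, so swap_states must never fire.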
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
input_batch.swap_states.assert_not_called()
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
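        # max_model_len=1024 with block_size=16 gives max_blocks=64, so the
        # 10-column table should come back padded to 64 columns with the
        # original entries preserved in the leading columns.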
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
    def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
builder.max_blocks = 4
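        # Forcing max_blocks below the input width (10 columns) exercises the
        # truncation path: only the first 4 columns should survive.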
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_model_config = MagicMock()
mock_model_config.max_model_len = 1024
mock_model_config.get_head_size.return_value = 64
mock_model_config.dtype = torch.float16
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = MagicMock(block_size=16)
mock_vllm_config.scheduler_config = MagicMock(
max_num_seqs=4, enable_chunked_prefill=False)
mock_vllm_config.speculative_config = None
mock_device = torch.device('cpu')
builder = AscendSFATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))

View File

@@ -1,111 +0,0 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from tests.ut.base import TestBase
init_cache_hf_modules_path = "vllm.utils.import_utils.init_cached_hf_modules"
class TestNPUTorchairWorker(TestBase):
def setUp(self):
self.cache_config_mock = MagicMock(spec=CacheConfig)
self.cache_config_mock.cache_type = "auto"
self.model_config_mock = MagicMock(spec=ModelConfig)
self.model_config_mock.dtype = torch.float16
self.model_config_mock.trust_remote_code = False
self.hf_config_mock = MagicMock()
self.hf_config_mock.model_type = "test_model"
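        # MagicMock auto-creates attributes, so hasattr() is always True here;
        # deleting index_topk makes later hasattr checks return False,
        # emulating an hf_config without that field.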
if hasattr(self.hf_config_mock, 'index_topk'):
delattr(self.hf_config_mock, 'index_topk')
self.model_config_mock.hf_config = self.hf_config_mock
self.parallel_config_mock = MagicMock(spec=ParallelConfig)
self.vllm_config_mock = MagicMock(spec=VllmConfig)
self.vllm_config_mock.cache_config = self.cache_config_mock
self.vllm_config_mock.model_config = self.model_config_mock
self.vllm_config_mock.parallel_config = self.parallel_config_mock
self.vllm_config_mock.additional_config = None
self.vllm_config_mock.load_config = None
self.vllm_config_mock.scheduler_config = None
self.vllm_config_mock.device_config = None
self.vllm_config_mock.compilation_config = None
self.local_rank = 0
self.rank = 0
self.distributed_init_method = "tcp://localhost:12345"
self.is_driver_worker = False
@patch(
"vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
)
@patch("vllm_ascend.worker.worker_v1.NPUPlatform")
def test_init_device(self, mock_platform, mock_init_dist_env):
from vllm_ascend.worker.worker_v1 import NPUWorker
mock_platform.mem_get_info.return_value = (1000, 2000)
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.local_rank = 1
worker.model_config = MagicMock()
worker.model_config.seed = 42
worker.vllm_config = MagicMock()
worker.parallel_config = MagicMock()
worker.parallel_config.local_world_size = 0
worker.parallel_config.data_parallel_size = 1
result = worker._init_device()
mock_platform.set_device.assert_called_once()
call_args = mock_platform.set_device.call_args[0][0]
self.assertEqual(str(call_args), "npu:1")
mock_platform.empty_cache.assert_called_once()
mock_platform.seed_everything.assert_called_once_with(42)
mock_platform.mem_get_info.assert_called_once()
mock_init_dist_env.assert_called_once()
self.assertEqual(str(result), "npu:1")
self.assertEqual(worker.init_npu_memory, 1000)
@patch(
"vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment"
)
@patch("vllm_ascend.worker.worker_v1.NPUPlatform")
def test_init_device_torchair_worker(self, mock_platform,
mock_init_dist_env):
from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker
mock_platform.mem_get_info.return_value = (1000, 2000)
with patch.object(NPUTorchairWorker, "__init__",
lambda x, **kwargs: None):
worker = NPUTorchairWorker()
worker.local_rank = 1
worker.model_config = MagicMock()
worker.model_config.seed = 42
worker.vllm_config = MagicMock()
worker.parallel_config = MagicMock()
worker.parallel_config.local_world_size = 0
worker.parallel_config.data_parallel_size = 1
result = worker._init_device()
mock_platform.set_device.assert_called_once()
call_args = mock_platform.set_device.call_args[0][0]
self.assertEqual(str(call_args), "npu:1")
mock_platform.empty_cache.assert_called_once()
mock_platform.seed_everything.assert_called_once_with(42)
mock_platform.mem_get_info.assert_called_once()
mock_init_dist_env.assert_called_once()
self.assertEqual(str(result), "npu:1")
self.assertEqual(worker.init_npu_memory, 1000)

View File

@@ -1,164 +0,0 @@
import os
from concurrent.futures import ThreadPoolExecutor
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair import utils
class TestTorchairUtils(TestBase):
def test_get_torchair_current_work_dir(self):
cache_dir = utils.TORCHAIR_CACHE_DIR
work_dir = utils._get_torchair_current_work_dir()
self.assertEqual(cache_dir, work_dir)
work_dir = utils._get_torchair_current_work_dir("test")
self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
def test_torchair_cache_dir(self):
utils.write_kv_cache_bytes_to_file(0, 100)
self.assertTrue(utils.check_torchair_cache_exist(),
"Create torchair cache dir failed")
self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
"Create kv cache bytes cache dir failed")
kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
self.assertEqual(100, kv_cache_bytes)
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_torchair_cache_dir_multiple_ranks(self):
ranks = [0, 1, 2, 3]
values = [100, 200, 300, 400]
with ThreadPoolExecutor() as executor:
executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
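        # Each rank writes its own cache file concurrently; the per-rank reads
        # below verify there is no cross-rank interference.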
for rank, expected in zip(ranks, values):
self.assertEqual(expected,
utils.read_kv_cache_bytes_from_file(rank))
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_delete_torchair_cache_file_multiple_times(self):
utils.write_kv_cache_bytes_to_file(0, 100)
utils.delete_torchair_cache_file()
for i in range(5):
try:
utils.delete_torchair_cache_file()
except FileNotFoundError:
self.fail(
f"Unexpected FileNotFoundError on delete call #{i+2}")
@patch('vllm.ModelRegistry')
def test_register_torchair_model(self, mock_model_registry):
utils.register_torchair_model()
self.assertEqual(mock_model_registry.register_model.call_count, 7)
call_args_list = mock_model_registry.register_model.call_args_list
expected_registrations = [
("DeepSeekMTPModel",
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
),
("DeepseekV2ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
),
("DeepseekV3ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
),
("DeepseekV32ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
),
("Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
("Qwen3MoeForCausalLM",
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
),
("PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
]
for i, (expected_name,
expected_path) in enumerate(expected_registrations):
args, kwargs = call_args_list[i]
self.assertEqual(args[0], expected_name)
self.assertEqual(args[1], expected_path)
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
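        # @patch decorators apply bottom-up: the innermost patch
        # (npu_format_cast) maps to the first mock argument, while FusedMoE is
        # replaced via new= and therefore injects no argument.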
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
self.assertEqual(fused_moe.w13_weight.data, 1)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
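        # The weight already reports the NZ format, so
        # converting_weight_acl_format should skip the cast entirely.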
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 0
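        # is_enable_nz returns falsy, so NZ conversion is disabled and no cast
        # should be issued.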
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()