[CORE]initial support for torchair with non-mla backend (#1506)
### What this PR does / why we need it? This PR supports torchair graph mode with non-mla backend on both 800IA2 and 300I Duo platforms. The main change is to add `attention_v1_torchair.py` to support specific attention related operations that are required by torchair. ### Does this PR introduce _any_ user-facing change? Before this PR, vLLM-Ascend only allows deepseek to use torchair. Now we can also use it with pangu. Besides, we add a support model list to control which type of models that can use torchair. ### How was this patch tested? We have test it with PanguProMoE on both 800IA2 and 300I Duo platforms, and model generates answer normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -165,3 +165,20 @@ def test_models_distributed_DeepSeek_W8A8():
|
||||
quantization="ascend",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
def test_models_distributed_pangu():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
max_tokens = 5
|
||||
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=4,
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
@@ -99,3 +99,63 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
|
||||
},
|
||||
}
|
||||
_deepseek_torchair_test_fixture(additional_config)
|
||||
|
||||
|
||||
def _pangu_torchair_test_fixture(
|
||||
additional_config: Dict,
|
||||
*,
|
||||
tensor_parallel_size=4,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# torchair is only work without chunked-prefill now
|
||||
kwargs = {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"refresh": True,
|
||||
}
|
||||
additional_config.update(**kwargs)
|
||||
|
||||
with VllmRunner(
|
||||
"vllm-ascend/pangu-pro-moe-pruing",
|
||||
dtype="half",
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend="mp",
|
||||
enforce_eager=False,
|
||||
additional_config=additional_config,
|
||||
) as vllm_model:
|
||||
# use greedy sampler to make sure the generated results are fix
|
||||
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
|
||||
|
||||
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
|
||||
# with 2 hidden layers, thus the golden results seems inaccurate.
|
||||
# This will only change if accuracy changes with the official weights
|
||||
# of PanguProMoE.
|
||||
golden_results = [
|
||||
'Hello, my name is Remempondeprecatedmiot忱',
|
||||
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
|
||||
'The capital of France is Rememvoud administrativ Remem投',
|
||||
'The future of AI isotope Segnali Zoeken精细化 supus',
|
||||
]
|
||||
|
||||
assert len(golden_results) == len(vllm_output)
|
||||
for i in range(len(vllm_output)):
|
||||
assert golden_results[i] == vllm_output[i][1]
|
||||
print(f"Generated text: {vllm_output[i][1]!r}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
|
||||
reason="torchair graph is not supported on v0")
|
||||
def test_e2e_pangu_with_torchair():
|
||||
additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
_pangu_torchair_test_fixture(additional_config)
|
||||
|
||||
315
tests/ut/ops/test_rotary_embedding.py
Normal file
315
tests/ut/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
|
||||
native_rope_deepseek_forward,
|
||||
rope_forward_oot, rotate_half,
|
||||
yarn_find_correction_dim,
|
||||
yarn_get_mscale)
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
def test_custom_rotary_embedding_enabled(self):
|
||||
# Test when all conditions are True
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Test when dtype is not float16
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
query = self.query.to(torch.float32)
|
||||
result = custom_rotary_embedding_enabled(query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when neox_style is False
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, False,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when head_size is not divisible by 32
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size + 1)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when custom op is disabled
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=False):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestRopeForwardOot(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_torchair_enabled_base(self,
|
||||
mock_get_ascend_config):
|
||||
# Setup mock for torchair enabled
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.mock_self.forward_native.assert_called_once_with(
|
||||
self.positions, self.query, self.key, None)
|
||||
self.assertTrue(torch.equal(result_q, self.query))
|
||||
self.assertTrue(torch.equal(result_k, self.key))
|
||||
|
||||
@patch('torch.ops._C')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled, mock_is_310p,
|
||||
mock_get_ascend_config, mock__c):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
non_contig_query, non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
rope_forward_oot(self.mock_self, self.positions, self.query,
|
||||
self.key, offsets)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = rope_forward_oot(self.mock_self,
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
def __init__(self, max_seq_len=2048, is_neox_style=True):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.rotary_dim = 1
|
||||
self.base = 1
|
||||
|
||||
|
||||
class TestNativeRopeDeepseekForward(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_cache_handling(
|
||||
self, mock_rope_forward_oot, mock_set_cache):
|
||||
# Test cache situation is true
|
||||
module = MockRopeModule(max_seq_len=1024)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
max_seq_len=2048)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == (1, 128)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule(is_neox_style=False)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
|
||||
class TestRotateHalf(unittest.TestCase):
|
||||
|
||||
def test_rotate_half_even_dim(self):
|
||||
# Test with even dimension
|
||||
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
|
||||
result = rotate_half(x)
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnFindCorrectionDim(unittest.TestCase):
|
||||
|
||||
def test_basic_case(self):
|
||||
# Test with standard values
|
||||
num_rotations = 100
|
||||
dim = 512
|
||||
base = 10000
|
||||
max_position_embeddings = 2048
|
||||
|
||||
result = yarn_find_correction_dim(num_rotations, dim, base,
|
||||
max_position_embeddings)
|
||||
|
||||
# Calculate expected value manually
|
||||
expected = (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnGetMscale(unittest.TestCase):
|
||||
|
||||
def test_scale_less_than_or_equal_1(self):
|
||||
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
|
||||
|
||||
def test_scale_greater_than_1(self):
|
||||
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
|
||||
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
|
||||
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
|
||||
(math.e, 1.0, 1.0 + 0.1)]
|
||||
|
||||
for scale, mscale, expected in test_cases:
|
||||
result = yarn_get_mscale(scale, mscale)
|
||||
self.assertAlmostEqual(
|
||||
result,
|
||||
expected,
|
||||
places=6,
|
||||
msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
@@ -381,6 +381,58 @@ class TestAscendC8KVCacheMethod(TestBase):
|
||||
self.assertEqual(mock_scatter.call_count, 2)
|
||||
self.assertTrue(torch.equal(result, expected_output))
|
||||
|
||||
@patch('torch_npu.npu_scatter_nd_update_')
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
def test_apply_attn_metadata_without_decode(self, mock_quant,
|
||||
mock_scatter):
|
||||
|
||||
num_tokens = 2
|
||||
query = torch.randn(num_tokens,
|
||||
self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attn_metadata = MagicMock(spec=[
|
||||
'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
|
||||
'attn_mask'
|
||||
])
|
||||
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
attn_metadata.seq_lens = [10, 10]
|
||||
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
|
||||
attn_metadata.slot_mapping = torch.tensor([0, 1])
|
||||
attn_metadata.attn_mask = None
|
||||
|
||||
block_size = 16
|
||||
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
kv_cache = (key_cache, value_cache)
|
||||
|
||||
mock_quant.side_effect = [key, value]
|
||||
|
||||
self.layer.key_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.layer.value_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.method.process_weights_after_loading(self.layer)
|
||||
|
||||
expected_output = torch.randn(
|
||||
num_tokens, self.layer.num_heads * self.layer.head_size)
|
||||
with patch('torch_npu.npu_incre_flash_attention',
|
||||
return_value=expected_output):
|
||||
result = self.method.apply(self.layer, query, key, value, kv_cache,
|
||||
attn_metadata,
|
||||
self.attention_type.DECODER, 1.0,
|
||||
output)
|
||||
|
||||
self.assertEqual(mock_quant.call_count, 2)
|
||||
self.assertEqual(mock_scatter.call_count, 2)
|
||||
self.assertTrue(torch.equal(result, expected_output))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
|
||||
|
||||
@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config,
|
||||
check_torchair_supported,
|
||||
clear_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
|
||||
@@ -242,3 +243,10 @@ class TestAscendConfig(unittest.TestCase):
|
||||
test_vllm_config.model_config = fake_model_config
|
||||
init_ascend_config(test_vllm_config)
|
||||
check_ascend_config(test_vllm_config, False)
|
||||
|
||||
def test_check_torchair_supported(self):
|
||||
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
|
||||
('qwen', False), ('llama', False)]
|
||||
for model_type, expected_output in test_cases:
|
||||
self.assertEqual(check_torchair_supported(model_type),
|
||||
expected_output)
|
||||
|
||||
@@ -292,23 +292,6 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
self.assertTrue("Model config is missing" in cm.output[0])
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch("vllm.envs.VLLM_MLA_DISABLE", True)
|
||||
def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self.mock_ascend_config.torchair_graph_config.enabled = True
|
||||
mock_init_ascend.return_value = self.mock_ascend_config
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
importlib.reload(platform)
|
||||
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
|
||||
self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@@ -502,7 +485,13 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
mock_scheduler.initialize_from_config.assert_called_once()
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -515,7 +504,35 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertEqual(result,
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_torchair(self,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
dtype="float16",
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
use_mla=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
||||
)
|
||||
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -529,7 +546,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_mla_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -543,7 +566,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention.AscendMLAAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_default_case(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
|
||||
Reference in New Issue
Block a user