[CORE]initial support for torchair with non-mla backend (#1506)
### What this PR does / why we need it? This PR supports torchair graph mode with non-mla backend on both 800IA2 and 300I Duo platforms. The main change is to add `attention_v1_torchair.py` to support specific attention related operations that are required by torchair. ### Does this PR introduce _any_ user-facing change? Before this PR, vLLM-Ascend only allows deepseek to use torchair. Now we can also use it with pangu. Besides, we add a support model list to control which type of models that can use torchair. ### How was this patch tested? We have test it with PanguProMoE on both 800IA2 and 300I Duo platforms, and model generates answer normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -292,23 +292,6 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
self.assertTrue("Model config is missing" in cm.output[0])
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch("vllm.envs.VLLM_MLA_DISABLE", True)
|
||||
def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self.mock_ascend_config.torchair_graph_config.enabled = True
|
||||
mock_init_ascend.return_value = self.mock_ascend_config
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
importlib.reload(platform)
|
||||
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
|
||||
self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@@ -502,7 +485,13 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
mock_scheduler.initialize_from_config.assert_called_once()
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -515,7 +504,35 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertEqual(result,
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_torchair(self,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
dtype="float16",
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
use_mla=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
||||
)
|
||||
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -529,7 +546,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_mla_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -543,7 +566,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention.AscendMLAAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_default_case(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
|
||||
Reference in New Issue
Block a user