[CORE] Initial support for torchair with non-MLA backend (#1506)

### What this PR does / why we need it?
This PR adds support for torchair graph mode with the non-MLA attention
backend on both the 800IA2 and 300I Duo platforms. The main change is a new
file, `attention_v1_torchair.py`, which implements the attention-related
operations that torchair requires.

### Does this PR introduce _any_ user-facing change?
Before this PR, vLLM-Ascend allowed only DeepSeek models to use torchair.
It can now also be used with Pangu. In addition, a supported-model list
controls which model types may use torchair.

### How was this patch tested?
We have tested it with PanguProMoE on both the 800IA2 and 300I Duo platforms,
and the model generates answers normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
Angazenn
2025-07-03 22:21:42 +08:00
committed by GitHub
parent 9fbd8017c0
commit a5f33590d3
19 changed files with 1130 additions and 84 deletions

View File

@@ -28,6 +28,6 @@ jobs:
       - name: Run codespell check
         run: |
           CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
-          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')
+          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever')
           codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"

View File

@@ -86,7 +86,7 @@ jobs:
       - name: Run codespell check
         run: |
           CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
-          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')
+          CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever')
           codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
       - name: Analysing the code with ruff

View File

@@ -40,14 +40,14 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
-| `enabled` | bool | `False` | Whether to enable torchair graph mode |
-| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
-| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
+| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE support torchair graph mode. |
+| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effect on models using MLA (e.g., DeepSeek). |
+| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effect on DeepSeek MoE models. |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
-| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout |
+| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effect on models using MLA (e.g., DeepSeek). |
 **ascend_scheduler_config**
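As a rough illustration of how these options fit together, the snippet below assembles an `additional_config` dictionary that turns on torchair graph mode. It assumes the table above documents `torchair_graph_config`, and it mirrors the e2e test in this PR, which enables `ascend_scheduler_config` alongside torchair. The helper name `build_torchair_additional_config` and its validation are hypothetical and not part of vLLM Ascend.

```python
from typing import Any, Dict, List, Optional


def build_torchair_additional_config(
    graph_batch_sizes: Optional[List[int]] = None,
    use_cached_graph: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble an additional_config dict for torchair."""
    sizes = list(graph_batch_sizes or [])
    if any(size <= 0 for size in sizes):
        raise ValueError("graph_batch_sizes must contain positive integers")
    return {
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": use_cached_graph,
            "graph_batch_sizes": sizes,
        },
        # The e2e test in this PR enables the ascend scheduler together
        # with torchair, so the sketch does the same.
        "ascend_scheduler_config": {"enabled": True},
    }


config = build_torchair_additional_config(graph_batch_sizes=[1, 4, 8])
print(config["torchair_graph_config"]["enabled"])
```

The resulting dictionary would be passed as `additional_config` when the engine is constructed, in the same way the tests pass it to `VllmRunner`.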

View File

@@ -12,7 +12,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa
 There are two kinds of graph mode supported by vLLM Ascend:
 - **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, only Qwen series models are well tested.
-- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
+- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported. In v0.9.1rc2, PanguProMoE is also supported with torchair.
 ## Using ACLGraph
 ACLGraph is enabled by default. Taking Qwen series models as an example, setting it to use the V1 Engine is enough.

View File

@@ -145,7 +145,7 @@ CODESPELL_EXCLUDES=(
 )
 CODESPELL_IGNORE_WORDS=(
-    '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn'
+    '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever'
 )
 # check spelling of specified files

View File

@@ -165,3 +165,20 @@ def test_models_distributed_DeepSeek_W8A8():
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_pangu():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5
with VllmRunner(
snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
max_model_len=8192,
enforce_eager=True,
dtype="auto",
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -99,3 +99,63 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
},
}
_deepseek_torchair_test_fixture(additional_config)
def _pangu_torchair_test_fixture(
additional_config: Dict,
*,
tensor_parallel_size=4,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# torchair only works without chunked prefill for now
kwargs = {
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}
additional_config.update(**kwargs)
with VllmRunner(
"vllm-ascend/pangu-pro-moe-pruing",
dtype="half",
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="mp",
enforce_eager=False,
additional_config=additional_config,
) as vllm_model:
# use greedy sampler to make sure the generated results are fixed
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, so the golden results look inaccurate.
# They should only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
def test_e2e_pangu_with_torchair():
additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
_pangu_torchair_test_fixture(additional_config)

View File

@@ -0,0 +1,315 @@
import math
import unittest
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
native_rope_deepseek_forward,
rope_forward_oot, rotate_half,
yarn_find_correction_dim,
yarn_get_mscale)
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=False):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestRopeForwardOot(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
mock_get_ascend_config):
# Setup mock for torchair enabled
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.mock_self.forward_native.assert_called_once_with(
self.positions, self.query, self.key, None)
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@patch('torch.ops._C')
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock_get_ascend_config, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
non_contig_query, non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
rope_forward_oot(self.mock_self, self.positions, self.query,
self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = rope_forward_oot(self.mock_self,
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
class TestNativeRopeDeepseekForward(TestBase):
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_cache_handling(
self, mock_rope_forward_oot, mock_set_cache):
# Test the case where max_seq_len (2048) exceeds the module's max_seq_len (1024)
module = MockRopeModule(max_seq_len=1024)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module,
positions,
query,
key,
max_seq_len=2048)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == (1, 128)
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_rope_forward_oot):
module = MockRopeModule(is_neox_style=False)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
class TestRotateHalf(unittest.TestCase):
def test_rotate_half_even_dim(self):
# Test with even dimension
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
result = rotate_half(x)
self.assertTrue(torch.allclose(result, expected))
class TestYarnFindCorrectionDim(unittest.TestCase):
def test_basic_case(self):
# Test with standard values
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
class TestYarnGetMscale(unittest.TestCase):
def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
def test_scale_greater_than_1(self):
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")
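The expectations in these tests imply simple closed forms for the YaRN helpers and for `rotate_half`. The sketch below restates them in plain Python (no torch) so the arithmetic can be checked by hand. The `_sketch` functions are hypothetical and are inferred from the test assertions above, not copied from `vllm_ascend.ops.rotary_embedding`; the real functions operate on tensors.

```python
import math
from typing import List


def yarn_get_mscale_sketch(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Scales at or below 1 leave the attention magnitude unchanged.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_find_correction_dim_sketch(num_rotations: float,
                                    dim: int,
                                    base: float = 10000.0,
                                    max_position_embeddings: int = 2048) -> float:
    # Same expression the test computes "manually" for the expected value.
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def rotate_half_sketch(x: List[float]) -> List[float]:
    # Split the vector in two halves and return [-second_half, first_half].
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


print(rotate_half_sketch([1.0, 2.0, 3.0, 4.0]))  # [-3.0, -4.0, 1.0, 2.0]
print(round(yarn_get_mscale_sketch(math.e, 1.0), 6))  # 1.1
```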

View File

@@ -381,6 +381,58 @@ class TestAscendC8KVCacheMethod(TestBase):
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch('torch_npu.npu_scatter_nd_update_')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_attn_metadata_without_decode(self, mock_quant,
mock_scatter):
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock(spec=[
'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
'attn_mask'
])
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
attn_metadata.seq_lens = [10, 10]
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
attn_metadata.slot_mapping = torch.tensor([0, 1])
attn_metadata.attn_mask = None
block_size = 16
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
kv_cache = (key_cache, value_cache)
mock_quant.side_effect = [key, value]
self.layer.key_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.layer.value_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.method.process_weights_after_loading(self.layer)
expected_output = torch.randn(
num_tokens, self.layer.num_heads * self.layer.head_size)
with patch('torch_npu.npu_incre_flash_attention',
return_value=expected_output):
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata,
self.attention_type.DECODER, 1.0,
output)
self.assertEqual(mock_quant.call_count, 2)
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('torch_npu._npu_flash_attention')
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):

View File

@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm_ascend.ascend_config import (check_ascend_config,
+                                       check_torchair_supported,
                                        clear_ascend_config, get_ascend_config,
                                        init_ascend_config)
@@ -242,3 +243,10 @@ class TestAscendConfig(unittest.TestCase):
         test_vllm_config.model_config = fake_model_config
         init_ascend_config(test_vllm_config)
         check_ascend_config(test_vllm_config, False)
+
+    def test_check_torchair_supported(self):
+        test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
+                      ('qwen', False), ('llama', False)]
+        for model_type, expected_output in test_cases:
+            self.assertEqual(check_torchair_supported(model_type),
+                             expected_output)

View File

@@ -292,23 +292,6 @@ class TestNPUPlatform(TestBase):
         self.platform.check_and_update_config(self.mock_vllm_config)
         self.assertTrue("Model config is missing" in cm.output[0])
 
-    @patch("vllm_ascend.utils.is_310p", return_value=False)
-    @patch("vllm_ascend.ascend_config.check_ascend_config")
-    @patch("vllm_ascend.ascend_config.init_ascend_config")
-    @patch("vllm.envs.VLLM_MLA_DISABLE", True)
-    def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled(
-            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
-        self.mock_ascend_config.torchair_graph_config.enabled = True
-        mock_init_ascend.return_value = self.mock_ascend_config
-        from vllm_ascend import platform
-        importlib.reload(platform)
-        self.platform.check_and_update_config(self.mock_vllm_config)
-        self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled)
-
     @patch("vllm_ascend.utils.is_310p", return_value=False)
     @patch("vllm_ascend.ascend_config.check_ascend_config")
     @patch("vllm_ascend.ascend_config.init_ascend_config")
@@ -502,7 +485,13 @@ class TestNPUPlatform(TestBase):
         self.platform.check_and_update_config(self.mock_vllm_config)
         mock_scheduler.initialize_from_config.assert_called_once()
 
-    def test_get_attn_backend_cls_use_v1_and_mla(self):
+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_ascend_config.return_value = mock_config
+
         result = self.platform.get_attn_backend_cls(
             selected_backend="ascend",
             head_size=64,
@@ -515,7 +504,35 @@ class TestNPUPlatform(TestBase):
         self.assertEqual(result,
                          "vllm_ascend.attention.mla_v1.AscendMLABackend")
 
-    def test_get_attn_backend_cls_use_v1_only(self):
+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_use_v1_and_torchair(self,
+                                                      mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+        mock_get_ascend_config.return_value = mock_config
+
+        result = self.platform.get_attn_backend_cls(
+            selected_backend="ascend",
+            head_size=64,
+            dtype="float16",
+            kv_cache_dtype="float16",
+            block_size=64,
+            use_v1=True,
+            use_mla=False,
+        )
+        self.assertEqual(
+            result,
+            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+        )
+
+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_ascend_config.return_value = mock_config
+
         result = self.platform.get_attn_backend_cls(
             selected_backend="ascend",
             head_size=64,
@@ -529,7 +546,13 @@ class TestNPUPlatform(TestBase):
             result,
             "vllm_ascend.attention.attention_v1.AscendAttentionBackend")
 
-    def test_get_attn_backend_cls_use_mla_only(self):
+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_ascend_config.return_value = mock_config
+
         result = self.platform.get_attn_backend_cls(
             selected_backend="ascend",
             head_size=64,
@@ -543,7 +566,13 @@ class TestNPUPlatform(TestBase):
             result,
             "vllm_ascend.attention.attention.AscendMLAAttentionBackend")
 
-    def test_get_attn_backend_cls_default_case(self):
+    @patch('vllm_ascend.platform.get_ascend_config')
+    def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
+        mock_config = MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_ascend_config.return_value = mock_config
+
         result = self.platform.get_attn_backend_cls(
             selected_backend="ascend",
             head_size=64,
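Taken together, the tests in this file pin down which backend class path is expected for each flag combination. The sketch below restates that mapping as a standalone function. It is a hypothetical reconstruction from the assertions, not the actual `get_attn_backend_cls` code: the order of the checks when MLA and torchair are both enabled is an assumption, and the result for the default case is not visible in this excerpt, so the sketch returns `None` there as a placeholder.

```python
from typing import Optional


def select_attn_backend_sketch(use_v1: bool, use_mla: bool,
                               torchair_enabled: bool) -> Optional[str]:
    """Hypothetical mapping inferred from the test assertions."""
    if use_v1 and use_mla:
        # Assumption: MLA is checked before torchair.
        return "vllm_ascend.attention.mla_v1.AscendMLABackend"
    if use_v1 and torchair_enabled:
        return ("vllm_ascend.attention.attention_v1_torchair."
                "AscendAttentionTorchairBackend")
    if use_v1:
        return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
    if use_mla:
        return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
    # The expected value for the default case is not shown in this excerpt.
    return None


print(select_attn_backend_sketch(use_v1=True, use_mla=False,
                                 torchair_enabled=True))
```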

View File

@@ -18,6 +18,15 @@ from typing import Optional
 import vllm.envs as envs
 from vllm.logger import logger
 
+TORCHAIR_MODEL_LIST = ["deepseek", "pangu"]
+
+
+def check_torchair_supported(model_type: str):
+    for supported_model in TORCHAIR_MODEL_LIST:
+        if supported_model in model_type.lower():
+            return True
+    return False
+
 
 class AscendConfig:
     """
@@ -141,10 +150,10 @@ def check_ascend_config(vllm_config, enforce_eager):
             # torchair_graph is supported for deepseek model only currently.
             if vllm_config.model_config:
                 model_type = vllm_config.model_config.hf_config.model_type
-                if "deepseek" not in model_type:
+                if not check_torchair_supported(model_type):
                     raise NotImplementedError(
-                        "Torchair graph mode only works with deepseek model."
-                    )
+                        "Torchair graph mode only works with following model types: "
+                        f"{TORCHAIR_MODEL_LIST}.")
         # aclgraph case
         else:
             # aclgraph doesn't work with deepseek model and only qwen model is well tested.
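The support check is a case-insensitive substring match, so any `model_type` that contains `deepseek` or `pangu` is accepted. A minimal standalone restatement of the function added in this file (written with `any`, which should be equivalent to the loop), run against the same cases the unit test uses:

```python
TORCHAIR_MODEL_LIST = ["deepseek", "pangu"]


def check_torchair_supported(model_type: str) -> bool:
    # Lower-case once, then look for any supported family name as a substring.
    lowered = model_type.lower()
    return any(name in lowered for name in TORCHAIR_MODEL_LIST)


for model_type in ("deepseek_v3", "PanguProMoE", "qwen", "llama"):
    print(model_type, check_torchair_supported(model_type))
```

Because the match is on substrings, a new model family only needs one entry in `TORCHAIR_MODEL_LIST` to cover all of its variants.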

View File

@@ -0,0 +1,506 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d)
class AscendAttentionTorchairBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ASCEND"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
return AscendAttentionTorchairBackendImpl
@staticmethod
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
return AscendTorchairMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
return AscendAttentionTorchairMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1]
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_value_cache.device)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
src_indices = src_to_dists[:, 0]
dst_indices = src_to_dists[:, 1]
for kv_cache in kv_caches:
key_caches = kv_cache[0]
value_caches = kv_cache[1]
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]
@dataclass
class AscendDecodeMetadata:
# Input positions for rotary embeddings, since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendTorchairMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: torch.Tensor
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens plus the new tokens. None if it is a decoding.
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# TODO: Indicates whether there are only prefill requests.
# FlashAttention can be used when there are only prefill requests.
# FlashAttention has better performance than PagedAttention,
# but it does not support decode requests.
is_only_prefill: bool = False
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
with_prefill_across_dp: bool = False
decode: Optional[AscendDecodeMetadata] = None
class AscendAttentionTorchairMetadataBuilder:
def __init__(self, runner):
self.runner = runner
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs
if isinstance(self.runner.graph_block_tables, np.ndarray):
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
else:
graph_block_tables = self.runner.graph_block_tables.to(
device=block_tables.device, dtype=block_tables.dtype)
num_blocks = block_tables.size(1)
if num_blocks <= max_blocks:
graph_block_tables[:num_seqs, :
num_blocks] = block_tables[:num_seqs, :
num_blocks]
else:
graph_block_tables[:num_seqs, :
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
dtype=torch.int32,
device=device)
block_table = self._get_graph_runner_block_tables(
num_reqs, block_table)
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
input_positions = torch.zeros(num_reqs,
dtype=torch.int32,
device=device).long()
slot_mapping = torch.full((num_reqs, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
query_start_loc = torch.full((num_reqs, ),
-1,
dtype=torch.int32,
device=device)
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1)
attn_metadata = AscendTorchairMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_lens=0,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly,
max_num_tokens_across_dp=num_reqs,
decode=decode_metadata)
return attn_metadata
def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
common_prefix_len,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
device = self.runner.device
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
block_table[:num_reqs])
query_lens = self.runner.query_lens
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True)
attn_mask = self.runner.attn_mask
attn_state = self.runner.attn_state
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
mask_nz = nd_to_nz_2d(attn_mask)
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
decode_metadata = None
use_torchair_graph = graph_pad_size > -1
if self.runner.attn_state in [
AscendAttentionState.DecodeOnly,
]:
max_seq_lens = seq_lens.max().item()
num_seqs = len(seq_lens)
if use_torchair_graph and self.runner.attn_state in [
AscendAttentionState.DecodeOnly,
]:
max_num_tokens_across_dp += graph_pad_size
pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value
] * graph_pad_size
seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((graph_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, padding])
block_table_padding = torch.zeros(
(graph_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
block_table = self._get_graph_runner_block_tables(
num_seqs + graph_pad_size, block_table)
padding_0 = torch.zeros(graph_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat([input_positions, padding_0])
decode_metadata = AscendDecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens,
attn_mask=None)
attn_metadata = AscendTorchairMetadata(
decode=decode_metadata,
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
return attn_metadata
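When the state is decode-only and a torchair graph is in use, `build` pads the batch up to the captured graph size: each padded request gets a sequence length of 1, a padding slot id, and an all-zero block-table row. The following is a minimal sketch of that padding with plain lists standing in for tensors. `pad_decode_batch` is a hypothetical helper name, and the value `-1` for `PAD_SLOT_ID` is an assumption about the constant this file imports.

```python
PAD_SLOT_ID = -1  # assumption: stands in for the constant used by the real code


def pad_decode_batch(seq_lens, slot_mapping, block_table, graph_pad_size):
    """Append dummy entries for graph_pad_size padded requests.

    Lists stand in for tensors; in decode-only mode each request
    contributes exactly one token, so all three inputs have one
    entry (or row) per request.
    """
    max_blocks = len(block_table[0]) if block_table else 0
    padded_seq_lens = seq_lens + [1] * graph_pad_size
    padded_slots = slot_mapping + [PAD_SLOT_ID] * graph_pad_size
    padded_table = [row[:] for row in block_table]
    padded_table += [[0] * max_blocks for _ in range(graph_pad_size)]
    return padded_seq_lens, padded_slots, padded_table


seq_lens, slots, table = pad_decode_batch(
    [7, 3], [38, 19], [[2, 0], [1, 0]], graph_pad_size=2)
print(seq_lens)  # [7, 3, 1, 1]
print(slots)     # [38, 19, -1, -1]
print(table)     # [[2, 0], [1, 0], [0, 0], [0, 0]]
```

Padding the sequence length with 1 rather than 0 mirrors `pad_value = 1` in `build`; the padded rows never correspond to real requests.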
class AscendAttentionTorchairBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.hidden_size = self.num_heads * self.head_size
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendTorchairMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = False,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads, head_size]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_quant = kv_cache is not None and kv_cache[0].numel(
) > 0 and kv_cache[0].dtype == torch.int8
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)
return output.view(num_tokens, self.hidden_size)
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
output = output.view(-1, self.num_heads, self.head_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"AscendAttentionTorchairBackendImpl")
if kv_cache is not None and kv_cache[0].numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
block_size = key_cache.shape[1]
slots_indices = slots.reshape(-1, 1)
block_indices = slots_indices // block_size
slots_indices = slots_indices % block_size
indices = torch.cat((block_indices, slots_indices), dim=1)
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if is_310p():
# align q k v output tensors
query = aligned_16(query)
key = aligned_16(key)
value = aligned_16(value)
output = aligned_16(output)
# do reformat in case of broadcasted tensors
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
output = output[:num_tokens, :, :]
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
decode_meta = attn_metadata.decode
assert decode_meta is not None
seq_lens = decode_meta.seq_lens_list
block_table = decode_meta.block_table
block_size = key_cache.shape[1]
query = query.view(num_tokens, 1,
self.num_heads * self.head_size).contiguous()
output = torch_npu.npu_incre_flash_attention(
query,
key_cache,
value_cache,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
actual_seq_lengths=seq_lens,
scale_value=self.scale,
block_table=block_table,
input_layout='BSH',
block_size=block_size)
else:
raise NotImplementedError(
"Torchair graph mode with non-MLA attention backend is still experimental. "
"v1 scheduler (chunked prefill) is not supported at this moment. Please "
"set 'ascend_scheduler_config':{'enabled':true} in additional_config "
"to use ascend scheduler.")
return output.view(num_tokens, self.hidden_size)
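
The cache-write path in `forward` splits each flat slot id from `slot_mapping` into a block index and an in-block offset before scattering keys and values into the cache. A minimal pure-Python sketch of that arithmetic follows, using the `slot_mapping` values from the metadata comment. `split_slots` is a hypothetical helper name, not part of this PR.

```python
def split_slots(slot_mapping, block_size):
    """Return (block_index, offset) pairs for flat slot ids.

    Equivalent to computing slot // block_size and slot % block_size,
    which is what the backend does on tensors before the scatter update.
    """
    return [divmod(slot, block_size) for slot in slot_mapping]


pairs = split_slots([35, 2, 17], block_size=16)
print(pairs)  # [(2, 3), (0, 2), (1, 1)]
```

Each pair is one row of the `indices` tensor built by `torch.cat((block_indices, slots_indices), dim=1)`; the offsets here are zero-based.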

View File

@@ -20,6 +20,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
import torch_npu
from torch import nn from torch import nn
from torch.nn import Parameter from torch.nn import Parameter
from transformers import PretrainedConfig from transformers import PretrainedConfig
@@ -56,8 +57,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.utils import is_310p from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -498,8 +500,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
global _ROUTER_SCALE global _ROUTER_SCALE
_ROUTER_SCALE = self.router_scale _ROUTER_SCALE = self.router_scale
if not use_h2p(): if not use_h2p():
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts.forward_impl(
router_logits=router_logits) hidden_states=hidden_states, router_logits=router_logits)
else: else:
# TODO: when using h2p, we have to skip communication in vLLM # TODO: when using h2p, we have to skip communication in vLLM
# native FusedMoE. here we need to design a better FusedMoE # native FusedMoE. here we need to design a better FusedMoE
@@ -608,6 +610,9 @@ class PanguProMoEAttention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
@@ -618,7 +623,19 @@ class PanguProMoEAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
if self.torchair_graph_enabled:
forward_kwargs = {'trace_flag': False}
output_shape = q.shape
attn_output = torch.empty(output_shape,
dtype=q.dtype,
device=q.device)
forward_kwargs['output'] = attn_output
attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
attn_metadata,
**forward_kwargs)
else:
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@@ -1097,4 +1114,10 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if is_310p() and "head" in name:
# On the 300I Duo platform, the matmul operation strongly prefers
# ACL_FORMAT_FRACTAL_NZ over ACL_FORMAT_FRACTAL_ND. Since lm_head is also
# implemented as a linear layer, we manually cast the format here.
param.data = torch_npu.npu_format_cast(param.data,
ACL_FORMAT_FRACTAL_NZ)
return loaded_params return loaded_params

View File

@@ -22,6 +22,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding) DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_custom_op, is_310p from vllm_ascend.utils import enable_custom_op, is_310p
@@ -38,6 +39,14 @@ def rope_forward_oot(
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None is_neox_style_override: Optional[bool] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config().torchair_graph_config.enabled:
return self.forward_native(
positions,
query,
key,
offsets,
)
import torch_npu import torch_npu
query_shape, key_shape = query.shape, key.shape query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device: if self.cos_sin_cache.device != query.device:

View File

@@ -132,19 +132,6 @@ def communication_adaptation_310p():
torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p( torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
torch.distributed.distributed_c10d.all_reduce) torch.distributed.distributed_c10d.all_reduce)
def reduce_scatter_310p(output_tensor, input_tensor, group=None):
rank = torch.distributed.get_rank(group)
world_size = torch.distributed.get_world_size(group)
torch.distributed.all_reduce(input_tensor,
torch.distributed.ReduceOp.SUM,
group,
async_op=False)
interval = input_tensor.shape[0] // world_size
output_tensor[:] = input_tensor[rank * interval:(rank + 1) * interval]
torch.distributed._reduce_scatter_base = reduce_scatter_310p
torch.distributed.distributed_c10d._reduce_scatter_base = reduce_scatter_310p
if is_310p(): if is_310p():
communication_adaptation_310p() communication_adaptation_310p()

View File

@@ -27,7 +27,8 @@ from torch.distributed.distributed_c10d import PrefixStore
from vllm.logger import logger from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum from vllm.platforms import Platform, PlatformEnum
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config)
from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p, from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
update_aclgraph_sizes) update_aclgraph_sizes)
@@ -154,14 +155,6 @@ class NPUPlatform(Platform):
else: else:
enforce_eager = getattr(model_config, "enforce_eager", False) enforce_eager = getattr(model_config, "enforce_eager", False)
if ascend_config.torchair_graph_config.enabled and envs.VLLM_MLA_DISABLE:
# torchair_graph is not supported for V1 without mla currently.
logger.warning(
"Torchair graph mode is still experimental and not supported for V1 without mla currently, "
"Fallback to eager mode.")
ascend_config.torchair_graph_config.enabled = False
enforce_eager = True
check_ascend_config(vllm_config, enforce_eager) check_ascend_config(vllm_config, enforce_eager)
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION: if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
@@ -229,6 +222,9 @@ class NPUPlatform(Platform):
kv_cache_dtype, block_size, use_v1, use_mla): kv_cache_dtype, block_size, use_v1, use_mla):
if use_v1 and use_mla: if use_v1 and use_mla:
return "vllm_ascend.attention.mla_v1.AscendMLABackend" return "vllm_ascend.attention.mla_v1.AscendMLABackend"
use_torchair = get_ascend_config().torchair_graph_config.enabled
if use_v1 and use_torchair:
return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
if use_v1: if use_v1:
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend" return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
if use_mla: if use_mla:
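
The hunk above inserts the torchair check between the MLA check and the generic V1 check, so the order of the branches decides which backend wins when several flags are set. The sketch below mirrors only the branches visible in this hunk. `select_backend` is a hypothetical standalone function, and returning `None` for the non-V1 cases is a placeholder, since those branches are not shown here.

```python
def select_backend(use_v1, use_mla, use_torchair):
    """Mirror the order of the checks visible in the hunk."""
    if use_v1 and use_mla:
        return "vllm_ascend.attention.mla_v1.AscendMLABackend"
    if use_v1 and use_torchair:
        return ("vllm_ascend.attention.attention_v1_torchair."
                "AscendAttentionTorchairBackend")
    if use_v1:
        return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
    return None  # placeholder: non-V1 branches are outside this hunk


# MLA takes precedence even when torchair is enabled.
print(select_backend(use_v1=True, use_mla=True, use_torchair=True))
# Non-MLA models with torchair enabled get the new backend.
print(select_backend(use_v1=True, use_mla=False, use_torchair=True))
```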

View File

@@ -406,9 +406,11 @@ class AscendC8KVCacheMethod:
"implemented for " "implemented for "
"PrefillCacheHit") "PrefillCacheHit")
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
if hasattr(attn_metadata, "decode"):
# torch_air # torch_air
# decode_meta = attn_metadata.decode decode_meta = attn_metadata.decode
# seq_lens = decode_meta.seq_lens_list seq_lens = decode_meta.seq_lens_list
else:
seq_lens = attn_metadata.seq_lens seq_lens = attn_metadata.seq_lens
block_size = key_cache.shape[1] block_size = key_cache.shape[1]
query = query.view(num_tokens, 1, layer.num_heads * query = query.view(num_tokens, 1, layer.num_heads *

View File

@@ -2049,9 +2049,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
from torchair import patch_for_hcom # type: ignore from torchair import patch_for_hcom # type: ignore
patch_for_hcom() patch_for_hcom()
if is_310p():
# On the 300I Duo platform, we need to patch broadcast. However, this patch is
# overwritten by patch_for_hcom in torchair, so we need to re-patch it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig() config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \ config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False) torch.npu.set_compile_mode(jit_compile=False)
@@ -2149,26 +2159,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size) kv_cache_spec.head_size)
if self.torchair_graph_enabled: if self.torchair_graph_enabled:
if len(kv_cache_shape) == 3:
# For non-MLA attention backends that use torchair, we pass the kv_cache in BSH
# layout ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
kv_caches[layer_name] = (
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device),
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
# atb reshape_and_cache does not support torchair.
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0],
ACL_FORMAT_FRACTAL_ND),
torch_npu.npu_format_cast(
kv_caches[layer_name][1],
ACL_FORMAT_FRACTAL_ND),
)
else:
# For MLA attention backends that use torchair.
layer_kv_cache_nope = torch.zeros( layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] + kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ), (self.model_config.hf_text_config.kv_lora_rank,
),
dtype=self.dtype, dtype=self.dtype,
pin_memory=True, pin_memory=True,
device=self.device) device=self.device)
layer_kv_cache_pe = torch.zeros( layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] + kv_cache_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim, (self.model_config.hf_text_config.
), qk_rope_head_dim, ),
dtype=self.dtype, dtype=self.dtype,
pin_memory=True, pin_memory=True,
device=self.device) device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope, kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe) layer_kv_cache_pe)
kv_caches[layer_name] = ( kv_caches[layer_name] = (
torch_npu.npu_format_cast(kv_caches[layer_name][0], torch_npu.npu_format_cast(
acl_format), kv_caches[layer_name][0], acl_format),
torch_npu.npu_format_cast(kv_caches[layer_name][1], torch_npu.npu_format_cast(
acl_format), kv_caches[layer_name][1], acl_format),
) )
else: else:
kv_caches[layer_name] = torch.zeros( kv_caches[layer_name] = torch.zeros(