[CORE]initial support for torchair with non-mla backend (#1506)
### What this PR does / why we need it? This PR supports torchair graph mode with non-mla backend on both 800IA2 and 300I Duo platforms. The main change is to add `attention_v1_torchair.py` to support specific attention related operations that are required by torchair. ### Does this PR introduce _any_ user-facing change? Before this PR, vLLM-Ascend only allows deepseek to use torchair. Now we can also use it with pangu. Besides, we add a support model list to control which type of models that can use torchair. ### How was this patch tested? We have test it with PanguProMoE on both 800IA2 and 300I Duo platforms, and model generates answer normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
2
.github/workflows/doc_codespell.yaml
vendored
2
.github/workflows/doc_codespell.yaml
vendored
@@ -28,6 +28,6 @@ jobs:
|
||||
- name: Run codespell check
|
||||
run: |
|
||||
CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
|
||||
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')
|
||||
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever')
|
||||
|
||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
|
||||
|
||||
2
.github/workflows/vllm_ascend_test.yaml
vendored
2
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -86,7 +86,7 @@ jobs:
|
||||
- name: Run codespell check
|
||||
run: |
|
||||
CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
|
||||
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')
|
||||
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever')
|
||||
|
||||
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
|
||||
- name: Analysing the code with ruff
|
||||
|
||||
@@ -40,14 +40,14 @@ The details of each config option are as follows:
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
| ---- | ---- | ------- | ----------- |
|
||||
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
|
||||
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
|
||||
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
|
||||
| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode |
|
||||
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). |
|
||||
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on DeepSeek moe models. |
|
||||
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
|
||||
| `use_cached_graph` | bool | `False` | Whether to use cached graph |
|
||||
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
|
||||
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
|
||||
| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout |
|
||||
| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
|
||||
|
||||
**ascend_scheduler_config**
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa
|
||||
|
||||
There are two kinds for graph mode supported by vLLM Ascend:
|
||||
- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, only Qwen series models are well tested.
|
||||
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported.
|
||||
- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported. In v0.9.1rc2, we also support PanguProMoe with torchair.
|
||||
|
||||
## Using ACLGraph
|
||||
ACLGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine is enough.
|
||||
|
||||
@@ -145,7 +145,7 @@ CODESPELL_EXCLUDES=(
|
||||
)
|
||||
|
||||
CODESPELL_IGNORE_WORDS=(
|
||||
'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn'
|
||||
'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever'
|
||||
)
|
||||
|
||||
# check spelling of specified files
|
||||
|
||||
@@ -165,3 +165,20 @@ def test_models_distributed_DeepSeek_W8A8():
|
||||
quantization="ascend",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
def test_models_distributed_pangu():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
max_tokens = 5
|
||||
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/pangu-pro-moe-pruing"),
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=4,
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
@@ -99,3 +99,63 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
|
||||
},
|
||||
}
|
||||
_deepseek_torchair_test_fixture(additional_config)
|
||||
|
||||
|
||||
def _pangu_torchair_test_fixture(
|
||||
additional_config: Dict,
|
||||
*,
|
||||
tensor_parallel_size=4,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# torchair is only work without chunked-prefill now
|
||||
kwargs = {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"refresh": True,
|
||||
}
|
||||
additional_config.update(**kwargs)
|
||||
|
||||
with VllmRunner(
|
||||
"vllm-ascend/pangu-pro-moe-pruing",
|
||||
dtype="half",
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend="mp",
|
||||
enforce_eager=False,
|
||||
additional_config=additional_config,
|
||||
) as vllm_model:
|
||||
# use greedy sampler to make sure the generated results are fix
|
||||
vllm_output = vllm_model.generate_greedy(example_prompts, 5)
|
||||
|
||||
# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
|
||||
# with 2 hidden layers, thus the golden results seems inaccurate.
|
||||
# This will only change if accuracy changes with the official weights
|
||||
# of PanguProMoE.
|
||||
golden_results = [
|
||||
'Hello, my name is Remempondeprecatedmiot忱',
|
||||
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
|
||||
'The capital of France is Rememvoud administrativ Remem投',
|
||||
'The future of AI isotope Segnali Zoeken精细化 supus',
|
||||
]
|
||||
|
||||
assert len(golden_results) == len(vllm_output)
|
||||
for i in range(len(vllm_output)):
|
||||
assert golden_results[i] == vllm_output[i][1]
|
||||
print(f"Generated text: {vllm_output[i][1]!r}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
|
||||
reason="torchair graph is not supported on v0")
|
||||
def test_e2e_pangu_with_torchair():
|
||||
additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
_pangu_torchair_test_fixture(additional_config)
|
||||
|
||||
315
tests/ut/ops/test_rotary_embedding.py
Normal file
315
tests/ut/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
|
||||
native_rope_deepseek_forward,
|
||||
rope_forward_oot, rotate_half,
|
||||
yarn_find_correction_dim,
|
||||
yarn_get_mscale)
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
def test_custom_rotary_embedding_enabled(self):
|
||||
# Test when all conditions are True
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Test when dtype is not float16
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
query = self.query.to(torch.float32)
|
||||
result = custom_rotary_embedding_enabled(query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when neox_style is False
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, False,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when head_size is not divisible by 32
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=True):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size + 1)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Test when custom op is disabled
|
||||
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
|
||||
return_value=False):
|
||||
result = custom_rotary_embedding_enabled(self.query, True,
|
||||
self.head_size)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestRopeForwardOot(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
self.positions = torch.tensor([1, 2, 3])
|
||||
self.query = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.key = torch.randn(3, 4, dtype=torch.float16)
|
||||
self.head_size = 32
|
||||
self.cos_sin_cache = torch.randn(3, 4)
|
||||
|
||||
# Mock self object for rope_forward_oot
|
||||
self.mock_self = MagicMock()
|
||||
self.mock_self.head_size = self.head_size
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = True
|
||||
self.mock_self.forward_native.return_value = (self.query, self.key)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_torchair_enabled_base(self,
|
||||
mock_get_ascend_config):
|
||||
# Setup mock for torchair enabled
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.mock_self.forward_native.assert_called_once_with(
|
||||
self.positions, self.query, self.key, None)
|
||||
self.assertTrue(torch.equal(result_q, self.query))
|
||||
self.assertTrue(torch.equal(result_k, self.key))
|
||||
|
||||
@patch('torch.ops._C')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled, mock_is_310p,
|
||||
mock_get_ascend_config, mock__c):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
self.query, self.key)
|
||||
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
|
||||
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
|
||||
non_contig_query, non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
rope_forward_oot(self.mock_self, self.positions, self.query,
|
||||
self.key, offsets)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = rope_forward_oot(self.mock_self,
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
def __init__(self, max_seq_len=2048, is_neox_style=True):
|
||||
self.max_seq_len = max_seq_len
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.rotary_dim = 1
|
||||
self.base = 1
|
||||
|
||||
|
||||
class TestNativeRopeDeepseekForward(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_cache_handling(
|
||||
self, mock_rope_forward_oot, mock_set_cache):
|
||||
# Test cache situation is true
|
||||
module = MockRopeModule(max_seq_len=1024)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
max_seq_len=2048)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_key_reshaping(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule()
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == (1, 128)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot')
|
||||
def test_native_rope_deepseek_forward_non_neox_style(
|
||||
self, mock_rope_forward_oot):
|
||||
module = MockRopeModule(is_neox_style=False)
|
||||
positions = torch.tensor([1, 2, 3])
|
||||
query = torch.randn(1, 8, 128)
|
||||
key = torch.randn(1, 8, 128)
|
||||
|
||||
mock_rope_forward_oot.return_value = (query, key)
|
||||
|
||||
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
|
||||
key)
|
||||
|
||||
assert q_pe.shape == query.shape
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
|
||||
class TestRotateHalf(unittest.TestCase):
|
||||
|
||||
def test_rotate_half_even_dim(self):
|
||||
# Test with even dimension
|
||||
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
|
||||
result = rotate_half(x)
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnFindCorrectionDim(unittest.TestCase):
|
||||
|
||||
def test_basic_case(self):
|
||||
# Test with standard values
|
||||
num_rotations = 100
|
||||
dim = 512
|
||||
base = 10000
|
||||
max_position_embeddings = 2048
|
||||
|
||||
result = yarn_find_correction_dim(num_rotations, dim, base,
|
||||
max_position_embeddings)
|
||||
|
||||
# Calculate expected value manually
|
||||
expected = (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 *
|
||||
torch.log(torch.tensor(base)))
|
||||
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnGetMscale(unittest.TestCase):
|
||||
|
||||
def test_scale_less_than_or_equal_1(self):
|
||||
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
|
||||
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
|
||||
|
||||
def test_scale_greater_than_1(self):
|
||||
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
|
||||
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
|
||||
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
|
||||
(math.e, 1.0, 1.0 + 0.1)]
|
||||
|
||||
for scale, mscale, expected in test_cases:
|
||||
result = yarn_get_mscale(scale, mscale)
|
||||
self.assertAlmostEqual(
|
||||
result,
|
||||
expected,
|
||||
places=6,
|
||||
msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
@@ -381,6 +381,58 @@ class TestAscendC8KVCacheMethod(TestBase):
|
||||
self.assertEqual(mock_scatter.call_count, 2)
|
||||
self.assertTrue(torch.equal(result, expected_output))
|
||||
|
||||
@patch('torch_npu.npu_scatter_nd_update_')
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
def test_apply_attn_metadata_without_decode(self, mock_quant,
|
||||
mock_scatter):
|
||||
|
||||
num_tokens = 2
|
||||
query = torch.randn(num_tokens,
|
||||
self.layer.num_heads * self.layer.head_size)
|
||||
key = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
value = torch.randn(num_tokens,
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attn_metadata = MagicMock(spec=[
|
||||
'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
|
||||
'attn_mask'
|
||||
])
|
||||
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
attn_metadata.seq_lens = [10, 10]
|
||||
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
|
||||
attn_metadata.slot_mapping = torch.tensor([0, 1])
|
||||
attn_metadata.attn_mask = None
|
||||
|
||||
block_size = 16
|
||||
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
|
||||
self.layer.head_size)
|
||||
kv_cache = (key_cache, value_cache)
|
||||
|
||||
mock_quant.side_effect = [key, value]
|
||||
|
||||
self.layer.key_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.layer.value_antiquant_scale.data = torch.ones(
|
||||
self.layer.num_kv_heads * self.layer.head_size)
|
||||
self.method.process_weights_after_loading(self.layer)
|
||||
|
||||
expected_output = torch.randn(
|
||||
num_tokens, self.layer.num_heads * self.layer.head_size)
|
||||
with patch('torch_npu.npu_incre_flash_attention',
|
||||
return_value=expected_output):
|
||||
result = self.method.apply(self.layer, query, key, value, kv_cache,
|
||||
attn_metadata,
|
||||
self.attention_type.DECODER, 1.0,
|
||||
output)
|
||||
|
||||
self.assertEqual(mock_quant.call_count, 2)
|
||||
self.assertEqual(mock_scatter.call_count, 2)
|
||||
self.assertTrue(torch.equal(result, expected_output))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
|
||||
|
||||
@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config,
|
||||
check_torchair_supported,
|
||||
clear_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
|
||||
@@ -242,3 +243,10 @@ class TestAscendConfig(unittest.TestCase):
|
||||
test_vllm_config.model_config = fake_model_config
|
||||
init_ascend_config(test_vllm_config)
|
||||
check_ascend_config(test_vllm_config, False)
|
||||
|
||||
def test_check_torchair_supported(self):
|
||||
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
|
||||
('qwen', False), ('llama', False)]
|
||||
for model_type, expected_output in test_cases:
|
||||
self.assertEqual(check_torchair_supported(model_type),
|
||||
expected_output)
|
||||
|
||||
@@ -292,23 +292,6 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
self.assertTrue("Model config is missing" in cm.output[0])
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch("vllm.envs.VLLM_MLA_DISABLE", True)
|
||||
def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
self.mock_ascend_config.torchair_graph_config.enabled = True
|
||||
mock_init_ascend.return_value = self.mock_ascend_config
|
||||
|
||||
from vllm_ascend import platform
|
||||
|
||||
importlib.reload(platform)
|
||||
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
|
||||
self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled)
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@@ -502,7 +485,13 @@ class TestNPUPlatform(TestBase):
|
||||
self.platform.check_and_update_config(self.mock_vllm_config)
|
||||
mock_scheduler.initialize_from_config.assert_called_once()
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -515,7 +504,35 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertEqual(result,
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_v1_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_and_torchair(self,
|
||||
mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = True
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
dtype="float16",
|
||||
kv_cache_dtype="float16",
|
||||
block_size=64,
|
||||
use_v1=True,
|
||||
use_mla=False,
|
||||
)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
||||
)
|
||||
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -529,7 +546,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_use_mla_only(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
@@ -543,7 +566,13 @@ class TestNPUPlatform(TestBase):
|
||||
result,
|
||||
"vllm_ascend.attention.attention.AscendMLAAttentionBackend")
|
||||
|
||||
def test_get_attn_backend_cls_default_case(self):
|
||||
@patch('vllm_ascend.platform.get_ascend_config')
|
||||
def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
result = self.platform.get_attn_backend_cls(
|
||||
selected_backend="ascend",
|
||||
head_size=64,
|
||||
|
||||
@@ -18,6 +18,15 @@ from typing import Optional
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
|
||||
TORCHAIR_MODEL_LIST = ["deepseek", "pangu"]
|
||||
|
||||
|
||||
def check_torchair_supported(model_type: str):
|
||||
for supported_model in TORCHAIR_MODEL_LIST:
|
||||
if supported_model in model_type.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AscendConfig:
|
||||
"""
|
||||
@@ -141,10 +150,10 @@ def check_ascend_config(vllm_config, enforce_eager):
|
||||
# torchair_graph is supported for deepseek model only currently.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if "deepseek" not in model_type:
|
||||
if not check_torchair_supported(model_type):
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode only works with deepseek model."
|
||||
)
|
||||
"Torchair graph mode only works with following model types:"
|
||||
f"{TORCHAIR_MODEL_LIST}.")
|
||||
# aclgraph case
|
||||
else:
|
||||
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
|
||||
|
||||
506
vllm_ascend/attention/attention_v1_torchair.py
Normal file
506
vllm_ascend/attention/attention_v1_torchair.py
Normal file
@@ -0,0 +1,506 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
||||
return AscendAttentionTorchairBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AscendTorchairMetadata"]:
|
||||
return AscendTorchairMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
||||
return AscendAttentionTorchairMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
dst_kv_cache: List[torch.Tensor],
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
||||
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
||||
src_indices = src_to_dst[:, 0]
|
||||
dst_indices = src_to_dst[:, 1]
|
||||
|
||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
src_indices = src_to_dists[:, 0]
|
||||
dst_indices = src_to_dists[:, 1]
|
||||
|
||||
for kv_cache in kv_caches:
|
||||
key_caches = kv_cache[0]
|
||||
value_caches = kv_cache[1]
|
||||
key_caches[dst_indices] = key_caches[src_indices]
|
||||
value_caches[dst_indices] = value_caches[src_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendDecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendTorchairMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: torch.Tensor
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
query_start_loc: torch.Tensor
|
||||
query_lens: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# max value of number of tokens across dp group
|
||||
max_num_tokens_across_dp: int = 0
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor = None
|
||||
# TODO: Indicates whether there are only prefill requests.
|
||||
# FlashAttention can be used when there are only prefill requests.
|
||||
# FlashAttention has better performance than PageAtttention,
|
||||
# but it does not support decode requests.
|
||||
is_only_prefill: bool = False
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
with_prefill_across_dp: bool = False
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
|
||||
class AscendAttentionTorchairMetadataBuilder:
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self.torchair_graph_enabled = get_ascend_config(
|
||||
).torchair_graph_config.enabled
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
||||
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
else:
|
||||
graph_block_tables = self.runner.graph_block_tables.to(
|
||||
device=block_tables.device, dtype=block_tables.dtype)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[:num_seqs, :
|
||||
num_blocks] = block_tables[:num_seqs, :
|
||||
num_blocks]
|
||||
else:
|
||||
graph_block_tables[:num_seqs, :
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables[:num_seqs, :max_blocks]
|
||||
|
||||
def build_dummy(self, num_reqs: int,
|
||||
num_actual_tokens: int) -> AscendTorchairMetadata:
|
||||
device = self.runner.device
|
||||
_, max_blocks = self.runner.graph_block_tables.shape
|
||||
block_table = torch.zeros((num_reqs, max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_reqs, block_table)
|
||||
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
input_positions = torch.zeros(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device).long()
|
||||
slot_mapping = torch.full((num_reqs, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
query_start_loc = torch.full((num_reqs, ),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=1)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_lens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
max_num_tokens_across_dp=num_reqs,
|
||||
decode=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def build(self,
|
||||
num_reqs,
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
common_prefix_len,
|
||||
graph_pad_size: int = -1,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False):
|
||||
|
||||
device = self.runner.device
|
||||
|
||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||
)
|
||||
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
query_lens = self.runner.query_lens
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True)
|
||||
attn_mask = self.runner.attn_mask
|
||||
|
||||
attn_state = self.runner.attn_state
|
||||
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.runner.device,
|
||||
non_blocking=True)
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
decode_metadata = None
|
||||
use_torchair_graph = graph_pad_size > -1
|
||||
if self.runner.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
num_seqs = len(seq_lens)
|
||||
if use_torchair_graph and self.runner.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
max_num_tokens_across_dp += graph_pad_size
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * graph_pad_size
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
padding = torch.full((graph_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(graph_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs + graph_pad_size, block_table)
|
||||
padding_0 = torch.zeros(graph_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat([input_positions, padding_0])
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=None)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
decode=decode_metadata,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.hidden_size = self.num_heads * self.head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendTorchairMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
trace_flag: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache: shape = [2, num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_quant = kv_cache is not None and kv_cache[0].numel(
|
||||
) > 0 and kv_cache[0].dtype == torch.int8
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
output)
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"AscendAttentionTorchairBackendImpl")
|
||||
|
||||
if kv_cache is not None and kv_cache[0].numel() > 0:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
block_size = key_cache.shape[1]
|
||||
slots_indices = slots.reshape(-1, 1)
|
||||
block_indices = slots_indices // block_size
|
||||
slots_indices = slots_indices % block_size
|
||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
|
||||
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if is_310p():
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
value = aligned_16(value)
|
||||
output = aligned_16(output)
|
||||
|
||||
# do reformat in case of broadcasted tensors
|
||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
output = output[:num_tokens, :, :]
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=attn_metadata.block_tables,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
block_table = decode_meta.block_table
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1,
|
||||
self.num_heads * self.head_size).contiguous()
|
||||
output = torch_npu.npu_incre_flash_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
actual_seq_lengths=seq_lens,
|
||||
scale_value=self.scale,
|
||||
block_table=block_table,
|
||||
input_layout='BSH',
|
||||
block_size=block_size)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode with non-MLA attention backend is still experimental."
|
||||
"v1 scheduler(chunked prefill) is not supported at this moment. Please"
|
||||
"setting 'ascend_scheduler_config':{'enabled':true} in additional_config"
|
||||
"to use ascend scheduler.")
|
||||
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
@@ -20,6 +20,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
from transformers import PretrainedConfig
|
||||
@@ -56,8 +57,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -498,8 +500,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
|
||||
global _ROUTER_SCALE
|
||||
_ROUTER_SCALE = self.router_scale
|
||||
if not use_h2p():
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
final_hidden_states = self.experts.forward_impl(
|
||||
hidden_states=hidden_states, router_logits=router_logits)
|
||||
else:
|
||||
# TODO: when using h2p, we have to skip communication in vLLM
|
||||
# native FusedMoE. here we need to design a better FusedMoE
|
||||
@@ -608,6 +610,9 @@ class PanguProMoEAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -618,7 +623,19 @@ class PanguProMoEAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
if self.torchair_graph_enabled:
|
||||
forward_kwargs = {'trace_flag': False}
|
||||
output_shape = q.shape
|
||||
attn_output = torch.empty(output_shape,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
forward_kwargs['output'] = attn_output
|
||||
attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
|
||||
attn_metadata,
|
||||
**forward_kwargs)
|
||||
else:
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -1097,4 +1114,10 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
if is_310p() and "head" in name:
|
||||
# on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than
|
||||
# ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented
|
||||
# by linear, we manually cast the format here.
|
||||
param.data = torch_npu.npu_format_cast(param.data,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
return loaded_params
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
|
||||
|
||||
@@ -38,6 +39,14 @@ def rope_forward_oot(
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
return self.forward_native(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
)
|
||||
|
||||
import torch_npu
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
|
||||
@@ -132,19 +132,6 @@ def communication_adaptation_310p():
|
||||
torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
|
||||
torch.distributed.distributed_c10d.all_reduce)
|
||||
|
||||
def reduce_scatter_310p(output_tensor, input_tensor, group=None):
|
||||
rank = torch.distributed.get_rank(group)
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
torch.distributed.all_reduce(input_tensor,
|
||||
torch.distributed.ReduceOp.SUM,
|
||||
group,
|
||||
async_op=False)
|
||||
interval = input_tensor.shape[0] // world_size
|
||||
output_tensor[:] = input_tensor[rank * interval:(rank + 1) * interval]
|
||||
|
||||
torch.distributed._reduce_scatter_base = reduce_scatter_310p
|
||||
torch.distributed.distributed_c10d._reduce_scatter_base = reduce_scatter_310p
|
||||
|
||||
|
||||
if is_310p():
|
||||
communication_adaptation_310p()
|
||||
|
||||
@@ -27,7 +27,8 @@ from torch.distributed.distributed_c10d import PrefixStore
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
|
||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
|
||||
update_aclgraph_sizes)
|
||||
|
||||
@@ -154,14 +155,6 @@ class NPUPlatform(Platform):
|
||||
else:
|
||||
enforce_eager = getattr(model_config, "enforce_eager", False)
|
||||
|
||||
if ascend_config.torchair_graph_config.enabled and envs.VLLM_MLA_DISABLE:
|
||||
# torchair_graph is not supported for V1 without mla currently.
|
||||
logger.warning(
|
||||
"Torchair graph mode is still experimental and not supported for V1 without mla currently, "
|
||||
"Fallback to eager mode.")
|
||||
ascend_config.torchair_graph_config.enabled = False
|
||||
enforce_eager = True
|
||||
|
||||
check_ascend_config(vllm_config, enforce_eager)
|
||||
|
||||
if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
|
||||
@@ -229,6 +222,9 @@ class NPUPlatform(Platform):
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
if use_v1 and use_mla:
|
||||
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
|
||||
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
||||
if use_v1 and use_torchair:
|
||||
return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
|
||||
if use_v1:
|
||||
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||
if use_mla:
|
||||
|
||||
@@ -406,10 +406,12 @@ class AscendC8KVCacheMethod:
|
||||
"implemented for "
|
||||
"PrefillCacheHit")
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
# torch_air
|
||||
# decode_meta = attn_metadata.decode
|
||||
# seq_lens = decode_meta.seq_lens_list
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
if hasattr(attn_metadata, "decode"):
|
||||
# torch_air
|
||||
decode_meta = attn_metadata.decode
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
else:
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1, layer.num_heads *
|
||||
layer.head_size).contiguous() # changed
|
||||
|
||||
@@ -2049,9 +2049,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
@@ -2149,27 +2159,50 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size)
|
||||
if self.torchair_graph_enabled:
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][0],
|
||||
acl_format),
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][1],
|
||||
acl_format),
|
||||
)
|
||||
if len(kv_cache_shape) == 3:
|
||||
# for non MLA attention backend that use torchair, we consider to pass kv_cache layout
|
||||
# of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
|
||||
|
||||
kv_caches[layer_name] = (
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device),
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device))
|
||||
# atb reshape_and_cache does not support torchair.
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
)
|
||||
else:
|
||||
# for MLA attention backend that use torchair.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.
|
||||
qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0], acl_format),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1], acl_format),
|
||||
)
|
||||
else:
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape,
|
||||
|
||||
Reference in New Issue
Block a user