[CORE]initial support for torchair with non-mla backend (#1506)

### What this PR does / why we need it?
This PR supports torchair graph mode with non-mla backend on both 800IA2
and 300I Duo platforms. The main change is to add
`attention_v1_torchair.py` to support specific attention related
operations that are required by torchair.

### Does this PR introduce _any_ user-facing change?
Before this PR, vLLM-Ascend only allows deepseek to use torchair. Now we
can also use it with pangu. Besides, we add a support model list to
control which type of models that can use torchair.

### How was this patch tested?
We have test it with PanguProMoE on both 800IA2 and 300I Duo platforms,
and model generates answer normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
Angazenn
2025-07-03 22:21:42 +08:00
committed by GitHub
parent 9fbd8017c0
commit a5f33590d3
19 changed files with 1130 additions and 84 deletions

View File

@@ -22,6 +22,7 @@ import torch
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import enable_custom_op, is_310p
@@ -38,6 +39,14 @@ def rope_forward_oot(
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config().torchair_graph_config.enabled:
return self.forward_native(
positions,
query,
key,
offsets,
)
import torch_npu
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device: