[CORE]initial support for torchair with non-mla backend (#1506)

### What this PR does / why we need it?
This PR supports torchair graph mode with non-mla backend on both 800IA2
and 300I Duo platforms. The main change is to add
`attention_v1_torchair.py` to support specific attention related
operations that are required by torchair.

### Does this PR introduce _any_ user-facing change?
Before this PR, vLLM-Ascend only allowed DeepSeek to use torchair. Now we
can also use it with Pangu. In addition, we add a supported-model list to
control which types of models can use torchair.

### How was this patch tested?
We have tested it with PanguProMoE on both 800IA2 and 300I Duo platforms,
and the model generates answers normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
Angazenn
2025-07-03 22:21:42 +08:00
committed by GitHub
parent 9fbd8017c0
commit a5f33590d3
19 changed files with 1130 additions and 84 deletions

View File

@@ -20,6 +20,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig
@@ -56,8 +57,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.utils import is_310p
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
logger = init_logger(__name__)
@@ -498,8 +500,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
global _ROUTER_SCALE
_ROUTER_SCALE = self.router_scale
if not use_h2p():
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
final_hidden_states = self.experts.forward_impl(
hidden_states=hidden_states, router_logits=router_logits)
else:
# TODO: when using h2p, we have to skip communication in vLLM
# native FusedMoE. here we need to design a better FusedMoE
@@ -608,6 +610,9 @@ class PanguProMoEAttention(nn.Module):
prefix=f"{prefix}.attn",
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
def forward(
self,
positions: torch.Tensor,
@@ -618,7 +623,19 @@ class PanguProMoEAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if self.torchair_graph_enabled:
forward_kwargs = {'trace_flag': False}
output_shape = q.shape
attn_output = torch.empty(output_shape,
dtype=q.dtype,
device=q.device)
forward_kwargs['output'] = attn_output
attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
attn_metadata,
**forward_kwargs)
else:
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -1097,4 +1114,10 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if is_310p() and "head" in name:
# on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than
# ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented
# by linear, we manually cast the format here.
param.data = torch_npu.npu_format_cast(param.data,
ACL_FORMAT_FRACTAL_NZ)
return loaded_params