[CORE] initial support for torchair with non-mla backend (#1506)
### What this PR does / why we need it?
This PR adds torchair graph mode support with the non-mla backend on both the 800I A2 and 300I Duo platforms. The main change is adding `attention_v1_torchair.py` to support the attention-related operations that torchair requires.

### Does this PR introduce _any_ user-facing change?
Before this PR, vLLM-Ascend only allowed deepseek to use torchair. Now we can also use it with pangu. Besides, we add a supported-model list to control which types of models can use torchair.

### How was this patch tested?
We tested it with PanguProMoE on both the 800I A2 and 300I Duo platforms, and the model generates answers normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
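The supported-model gate mentioned above could look roughly like the sketch below; the list contents and helper name are assumptions for illustration, not necessarily the PR's actual code:

```python
# Hypothetical allow-list gating which model families may enable torchair
# graph mode (name and contents assumed for illustration).
TORCHAIR_SUPPORTED_MODELS = ["deepseek", "pangu"]


def is_torchair_supported(model_type: str) -> bool:
    # Anything not on the list falls back to the default execution path.
    return any(supported in model_type.lower()
               for supported in TORCHAIR_SUPPORTED_MODELS)
```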
```diff
@@ -20,6 +20,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+import torch_npu
 from torch import nn
 from torch.nn import Parameter
 from transformers import PretrainedConfig
```
```diff
@@ -56,8 +57,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 logger = init_logger(__name__)
 
```
```diff
@@ -498,8 +500,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
         global _ROUTER_SCALE
         _ROUTER_SCALE = self.router_scale
         if not use_h2p():
-            final_hidden_states = self.experts(hidden_states=hidden_states,
-                                               router_logits=router_logits)
+            final_hidden_states = self.experts.forward_impl(
+                hidden_states=hidden_states, router_logits=router_logits)
         else:
             # TODO: when using h2p, we have to skip communication in vLLM
             # native FusedMoE. here we need to design a better FusedMoE
```
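The switch from `self.experts(...)` to `self.experts.forward_impl(...)` is worth a note. In vLLM, `FusedMoE.forward` is typically routed through a registered custom op that resolves the layer via the forward context, while `forward_impl` is the underlying implementation. A plausible reading, sketched below under that assumption, is that the direct call keeps the MoE computation visible to torchair's graph capture rather than hiding it behind an opaque dispatch op:

```python
# Sketch (assumed rationale): bypass the custom-op indirection so the traced
# graph sees the MoE computation itself rather than a dispatch wrapper.
final_hidden_states = self.experts.forward_impl(
    hidden_states=hidden_states,
    router_logits=router_logits)
```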
```diff
@@ -608,6 +610,9 @@ class PanguProMoEAttention(nn.Module):
             prefix=f"{prefix}.attn",
         )
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
     def forward(
         self,
         positions: torch.Tensor,
```
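`torchair_graph_enabled` is read once at construction time from the Ascend config. From the user side, the flag is driven through vLLM-Ascend's `additional_config`; a minimal sketch (the model path below is a placeholder):

```python
# Enabling torchair graph mode from the user side (sketch; the model path
# is a placeholder, not a real checkpoint).
from vllm import LLM

llm = LLM(
    model="/path/to/pangu-pro-moe",  # placeholder
    additional_config={
        "torchair_graph_config": {
            "enabled": True,  # read as ascend_config.torchair_graph_config.enabled
        },
    },
)
```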
```diff
@@ -618,7 +623,19 @@ class PanguProMoEAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v)
+        if self.torchair_graph_enabled:
+            forward_kwargs = {'trace_flag': False}
+            output_shape = q.shape
+            attn_output = torch.empty(output_shape,
+                                      dtype=q.dtype,
+                                      device=q.device)
+            forward_kwargs['output'] = attn_output
+            attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
+                                                 attn_metadata,
+                                                 **forward_kwargs)
+        else:
+            attn_output = self.attn(q, k, v)
 
         output, _ = self.o_proj(attn_output)
         return output
```
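The pattern in the torchair branch is to preallocate the output buffer and then call the attention implementation directly with `trace_flag=False`, instead of going through the regular `Attention.__call__` dispatch. Distilled into a standalone helper (the helper name is illustrative, not from the PR):

```python
import torch

# Distilled from the branch above; `run_attention` is an illustrative name.
def run_attention(self, q, k, v, kv_cache, attn_metadata):
    if self.torchair_graph_enabled:
        # Preallocate the output so the captured graph writes into a fixed
        # buffer instead of allocating inside the traced region.
        output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        return self.attn.impl.forward(self.attn, q, k, v, kv_cache,
                                      attn_metadata, trace_flag=False,
                                      output=output)
    # Eager path: normal dispatch through the Attention layer.
    return self.attn(q, k, v)
```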
```diff
@@ -1097,4 +1114,10 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
                     default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
+                if is_310p() and "head" in name:
+                    # On the 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much preferred
+                    # over ACL_FORMAT_FRACTAL_ND by the matmul operation. Since lm_head
+                    # is also implemented as a linear layer, we manually cast the format here.
+                    param.data = torch_npu.npu_format_cast(param.data,
+                                                           ACL_FORMAT_FRACTAL_NZ)
         return loaded_params
```
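The format cast generalizes to any weight that feeds a matmul on 310P-class devices: converting once at load time means every subsequent lm_head matmul runs on the faster NZ layout. A sketch of the same logic as a standalone helper:

```python
import torch
import torch_npu  # Ascend NPU extension; provides npu_format_cast

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p


def maybe_cast_to_nz(param: torch.nn.Parameter) -> None:
    # One-time layout conversion: ND -> NZ, preferred by matmul on 300I Duo.
    if is_310p():
        param.data = torch_npu.npu_format_cast(param.data,
                                               ACL_FORMAT_FRACTAL_NZ)
```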