feat: support torchair graph mode in v1 engine (#789)

### What this PR does / why we need it?
Support torchair graph mode with the v1 engine.

---------

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-12 19:14:07 +08:00
committed by GitHub
parent 4a2505f81f
commit efabd722eb
5 changed files with 585 additions and 82 deletions

View File

@@ -31,6 +31,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
import torch_npu
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
@@ -396,10 +397,22 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
else:
hidden_states_or_q_c = hidden_states
if self.enable_graph_mode:
return self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata)
forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = hidden_states.shape
output = torch.empty(output_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
forward_kwargs['output'] = output
output = self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata,
**forward_kwargs)
if envs.VLLM_USE_V1:
output = output.view(-1, output_shape[-1])
return output
else:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -653,4 +666,4 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass
pass