feat: support torchair graph mode in v1 engine (#789)
### What this PR does / why we need it? support torchair graph mode with v1 engine --------- Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
@@ -396,10 +397,22 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
if self.enable_graph_mode:
|
||||
return self.mla_attn.impl.forward(self.mla_attn,
|
||||
hidden_states_or_q_c,
|
||||
hidden_states, None, kv_cache,
|
||||
attn_metadata)
|
||||
forward_kwargs = {}
|
||||
if envs.VLLM_USE_V1:
|
||||
output_shape = hidden_states.shape
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states_or_q_c.dtype,
|
||||
device=hidden_states_or_q_c.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
output = self.mla_attn.impl.forward(self.mla_attn,
|
||||
hidden_states_or_q_c,
|
||||
hidden_states, None, kv_cache,
|
||||
attn_metadata,
|
||||
**forward_kwargs)
|
||||
if envs.VLLM_USE_V1:
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
else:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
@@ -653,4 +666,4 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user