feat: support torchair graph mode in v1 engine (#789)
### What this PR does / why we need it? support torchair graph mode with v1 engine --------- Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -1,11 +1,14 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
MLAAttentionImpl)
|
MLAAttentionImpl)
|
||||||
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
LinearBase, RowParallelLinear,
|
LinearBase, RowParallelLinear,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
@@ -51,6 +54,7 @@ class AscendMLAPrefillMetadata:
|
|||||||
""" Prefill Specific Metadata for Ascend"""
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: list[int]
|
query_lens: list[int]
|
||||||
|
seq_lens: list[int]
|
||||||
context_lens: torch.Tensor
|
context_lens: torch.Tensor
|
||||||
input_positions: torch.Tensor
|
input_positions: torch.Tensor
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -66,6 +70,7 @@ class AscendMLADecodeMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
max_seq_lens: int
|
max_seq_lens: int
|
||||||
|
seq_lens_list: list[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -195,11 +200,38 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
return modified_batch
|
return modified_batch
|
||||||
|
|
||||||
|
def _get_graph_runner_block_tables(
|
||||||
|
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||||
|
assert max_batch_size >= num_seqs
|
||||||
|
|
||||||
|
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
||||||
|
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
||||||
|
dtype=block_tables.dtype,
|
||||||
|
device=block_tables.device)
|
||||||
|
else:
|
||||||
|
graph_block_tables = self.runner.graph_block_tables.to(
|
||||||
|
device=block_tables.device, dtype=block_tables.dtype)
|
||||||
|
|
||||||
|
num_blocks = block_tables.size(1)
|
||||||
|
if num_blocks <= max_blocks:
|
||||||
|
graph_block_tables[:num_seqs, :
|
||||||
|
num_blocks] = block_tables[:num_seqs, :
|
||||||
|
num_blocks]
|
||||||
|
else:
|
||||||
|
graph_block_tables[:num_seqs, :
|
||||||
|
max_blocks] = block_tables[:num_seqs, :
|
||||||
|
max_blocks]
|
||||||
|
|
||||||
|
return graph_block_tables
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
num_reqs: int,
|
num_reqs: int,
|
||||||
num_actual_tokens: int,
|
num_actual_tokens: int,
|
||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
|
common_prefix_len: Optional[int] = None,
|
||||||
|
graph_pad_size: int = -1) -> AscendMLAMetadata:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||||
@@ -230,6 +262,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
prefill_metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=self.runner.attn_mask,
|
attn_mask=self.runner.attn_mask,
|
||||||
query_lens=query_lens[tokens_start:],
|
query_lens=query_lens[tokens_start:],
|
||||||
|
seq_lens=seq_lens,
|
||||||
context_lens=seq_lens[tokens_start:],
|
context_lens=seq_lens[tokens_start:],
|
||||||
input_positions=input_positions[tokens_start:],
|
input_positions=input_positions[tokens_start:],
|
||||||
block_table=block_table[reqs_start:, ...],
|
block_table=block_table[reqs_start:, ...],
|
||||||
@@ -238,12 +271,46 @@ class AscendMLAMetadataBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
|
use_torchair_graph = graph_pad_size != -1
|
||||||
if self._num_decodes > 0:
|
if self._num_decodes > 0:
|
||||||
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
||||||
|
seq_lens = seq_lens[:self._num_decode_tokens]
|
||||||
|
input_positions = input_positions[:self._num_decode_tokens]
|
||||||
|
block_table = block_table[:self._num_decode_tokens, ...]
|
||||||
|
if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
num_seqs = len(seq_lens)
|
||||||
|
if graph_pad_size != 0:
|
||||||
|
pad_value = 1
|
||||||
|
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||||
|
] * graph_pad_size
|
||||||
|
else:
|
||||||
|
padded_seq_lens = seq_lens.tolist()
|
||||||
|
|
||||||
|
seq_lens = torch.from_numpy(
|
||||||
|
np.array(padded_seq_lens).astype(np.int32))
|
||||||
|
padding = torch.full((graph_pad_size, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=slot_mapping.dtype,
|
||||||
|
device=slot_mapping.device)
|
||||||
|
slot_mapping = torch.cat([slot_mapping, padding])
|
||||||
|
block_table_padding = torch.zeros(
|
||||||
|
(graph_pad_size, ) + block_table.shape[1:],
|
||||||
|
dtype=block_table.dtype,
|
||||||
|
device=block_table.device)
|
||||||
|
block_table = torch.cat([block_table, block_table_padding],
|
||||||
|
dim=0)
|
||||||
|
block_table = self._get_graph_runner_block_tables(
|
||||||
|
num_seqs, block_table)
|
||||||
|
padding_0 = torch.zeros(graph_pad_size,
|
||||||
|
dtype=input_positions.dtype,
|
||||||
|
device=input_positions.device)
|
||||||
|
input_positions = torch.cat([input_positions, padding_0])
|
||||||
|
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
decode_metadata = AscendMLADecodeMetadata(
|
||||||
input_positions=input_positions[:self._num_decode_tokens],
|
input_positions=input_positions,
|
||||||
block_table=block_table[:self._num_decode_tokens, ...],
|
block_table=block_table,
|
||||||
seq_lens=seq_lens[:self._num_decode_tokens],
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_list=seq_lens.tolist(),
|
||||||
max_seq_lens=max_seq_lens)
|
max_seq_lens=max_seq_lens)
|
||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
return self.metadata_cls( # type: ignore
|
||||||
@@ -323,6 +390,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.kv_b_proj = kv_b_proj
|
self.kv_b_proj = kv_b_proj
|
||||||
self.o_proj = o_proj
|
self.o_proj = o_proj
|
||||||
|
|
||||||
|
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||||
|
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||||
# latter has an additional parameter to control FA2 vs FA3
|
# latter has an additional parameter to control FA2 vs FA3
|
||||||
@@ -332,6 +401,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
# functools.partial(flash_attn_varlen_func,
|
# functools.partial(flash_attn_varlen_func,
|
||||||
# fa_version=self.vllm_flash_attn_version)
|
# fa_version=self.vllm_flash_attn_version)
|
||||||
|
|
||||||
|
self.enable_graph_mode = False
|
||||||
|
additional_config = get_current_vllm_config().additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode", False)
|
||||||
|
|
||||||
def _v_up_proj_and_o_proj(self, x):
|
def _v_up_proj_and_o_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
@@ -485,15 +560,55 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
return self.o_proj(attn_output)[0]
|
return self.o_proj(attn_output)[0]
|
||||||
|
|
||||||
|
def exec_kv(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
kv_cache: Tuple,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
|
||||||
|
B = hidden_states.shape[0]
|
||||||
|
N = self.num_kv_heads
|
||||||
|
S = 1
|
||||||
|
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||||
|
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
|
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||||
|
kv,
|
||||||
|
self.kv_a_layernorm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
slots.to(torch.int64),
|
||||||
|
kv_cache[1],
|
||||||
|
kv_cache[0],
|
||||||
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
|
cache_mode="PA",
|
||||||
|
)
|
||||||
|
return k_pe, k_nope
|
||||||
|
|
||||||
|
def rope_single(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
B, N, D = x.shape
|
||||||
|
S = 1
|
||||||
|
x = x.view(B, N, S, D)
|
||||||
|
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
||||||
|
return x.view(B, N, D)
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
self,
|
self,
|
||||||
q_nope: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
|
k_nope: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: AscendMLAMetadata,
|
attn_metadata: AscendMLAMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert kv_c_and_k_pe_cache.numel() > 0
|
|
||||||
|
|
||||||
decode_meta = attn_metadata.decode
|
decode_meta = attn_metadata.decode
|
||||||
assert decode_meta is not None
|
assert decode_meta is not None
|
||||||
|
|
||||||
@@ -503,6 +618,36 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||||
dtype=q.dtype,
|
dtype=q.dtype,
|
||||||
device=q.device)
|
device=q.device)
|
||||||
|
if self.running_in_graph:
|
||||||
|
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||||
|
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||||
|
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||||
|
# shape of knope/k_pe for npu graph mode should be:
|
||||||
|
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||||
|
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
||||||
|
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||||
|
self.kv_lora_rank)
|
||||||
|
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||||
|
self.qk_rope_head_dim)
|
||||||
|
|
||||||
|
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
|
q_nope,
|
||||||
|
k_nope,
|
||||||
|
k_nope,
|
||||||
|
query_rope=q_pe,
|
||||||
|
key_rope=k_pe,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
input_layout="BNSD",
|
||||||
|
atten_mask=attn_metadata.attn_mask,
|
||||||
|
scale=self.scale,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
block_table=decode_meta.block_table,
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||||
|
)
|
||||||
|
else:
|
||||||
torch_npu._npu_paged_attention_mla(
|
torch_npu._npu_paged_attention_mla(
|
||||||
query=q,
|
query=q,
|
||||||
key_cache=kv_c_and_k_pe_cache,
|
key_cache=kv_c_and_k_pe_cache,
|
||||||
@@ -519,85 +664,135 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||||
k_c_normed: torch.Tensor, # key in unified attn
|
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
||||||
k_pe: torch.Tensor, # value in unified attn
|
k_pe: torch.Tensor, # value in unified attn
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
return output
|
return output
|
||||||
|
self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||||
num_actual_toks = attn_metadata.num_actual_tokens
|
num_actual_toks = attn_metadata.num_actual_tokens
|
||||||
|
if k_pe is None and not self.running_in_graph:
|
||||||
# Inputs and outputs may be padded for CUDA graphs
|
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||||
output_padded = output
|
hidden_states_or_kv_c_normed)[0].split(
|
||||||
output = output[:num_actual_toks, ...]
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
else:
|
||||||
k_pe = k_pe[:num_actual_toks, ...]
|
kv_c_normed = hidden_states_or_kv_c_normed
|
||||||
|
|
||||||
# Restore head dim (for rotary embedding)
|
|
||||||
k_pe = k_pe.unsqueeze(1)
|
|
||||||
|
|
||||||
assert attn_metadata.num_decodes is not None and \
|
assert attn_metadata.num_decodes is not None and \
|
||||||
attn_metadata.num_prefills is not None and \
|
attn_metadata.num_prefills is not None and \
|
||||||
attn_metadata.num_decode_tokens is not None
|
attn_metadata.num_decode_tokens is not None
|
||||||
|
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
if not self.running_in_graph:
|
||||||
|
# Inputs and outputs may be padded for CUDA graphs
|
||||||
|
output_padded = output
|
||||||
|
output = output[:num_actual_toks, ...]
|
||||||
|
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
||||||
|
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
||||||
|
if not self.running_in_graph:
|
||||||
|
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||||
decode_k_pe = k_pe[:num_decode_tokens]
|
|
||||||
|
|
||||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||||
|
k_pe = k_pe[:num_actual_toks, ...]
|
||||||
|
k_pe = k_pe.unsqueeze(1)
|
||||||
|
decode_k_pe = k_pe[:num_decode_tokens]
|
||||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
else:
|
||||||
|
decode_hs_or_q_c = hidden_states_or_q_c
|
||||||
if has_decode:
|
if has_decode:
|
||||||
|
decode_k_nope = None
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
decode_ql_nope, decode_q_pe = \
|
decode_ql_nope, decode_q_pe = \
|
||||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||||
|
if self.running_in_graph:
|
||||||
|
seq_len = self.rotary_emb.max_position_embeddings
|
||||||
|
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
||||||
|
dtype=decode_q_pe.dtype)
|
||||||
|
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
||||||
|
dtype=decode_q_pe.dtype)
|
||||||
|
cos = cos[attn_metadata.decode.input_positions]
|
||||||
|
sin = sin[attn_metadata.decode.input_positions]
|
||||||
|
cos = cos[:, None, None, :]
|
||||||
|
sin = sin[:, None, None, :]
|
||||||
|
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||||
|
decode_k_pe, decode_k_nope = self.exec_kv(
|
||||||
|
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||||
|
attn_metadata.slot_mapping)
|
||||||
|
else:
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.decode.input_positions,
|
attn_metadata.decode.input_positions,
|
||||||
decode_q_pe.contiguous(),
|
decode_q_pe.contiguous(),
|
||||||
decode_k_pe,
|
decode_k_pe,
|
||||||
max_seq_len=attn_metadata.decode.max_seq_lens)
|
max_seq_len=attn_metadata.decode.max_seq_lens)
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||||
|
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||||
|
if self.enable_graph_mode:
|
||||||
|
num_tokens = prefill_hs_or_q_c.shape[0]
|
||||||
|
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
||||||
|
-1)
|
||||||
|
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||||
|
# NOTE: When scaling not specified
|
||||||
|
ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
|
||||||
|
prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
|
||||||
|
prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
|
||||||
|
prefill_q_pe, prefill_k_pe = self.rotary_emb(
|
||||||
|
attn_metadata.prefill.input_positions, prefill_q_pe,
|
||||||
|
prefill_k_pe)
|
||||||
|
prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
|
||||||
|
prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
|
||||||
|
else:
|
||||||
|
prefill_q_pe, prefill_k_pe = self.rotary_emb(
|
||||||
|
attn_metadata.prefill.input_positions, prefill_q_pe,
|
||||||
|
prefill_k_pe)
|
||||||
|
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
||||||
|
else:
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.prefill.input_positions,
|
attn_metadata.prefill.input_positions,
|
||||||
prefill_q_pe.contiguous(),
|
prefill_q_pe.contiguous(),
|
||||||
prefill_k_pe,
|
prefill_k_pe,
|
||||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
||||||
|
if self.enable_graph_mode:
|
||||||
if kv_cache.numel() > 0:
|
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||||
|
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
|
slots = attn_metadata.slot_mapping
|
||||||
|
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
||||||
|
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||||
|
num_tokens, self.num_kv_heads, -1),
|
||||||
|
value=prefill_k_pe,
|
||||||
|
key_cache=kv_cache[0],
|
||||||
|
value_cache=kv_cache[1],
|
||||||
|
slot_indices=slots)
|
||||||
|
elif kv_cache.numel() > 0:
|
||||||
key = torch.cat([
|
key = torch.cat([
|
||||||
k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
|
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
||||||
|
k_pe
|
||||||
],
|
],
|
||||||
dim=2)
|
dim=2)
|
||||||
torch_npu._npu_reshape_and_cache_siso(
|
torch_npu._npu_reshape_and_cache_siso(
|
||||||
key=key,
|
key=key,
|
||||||
key_cache=kv_cache,
|
key_cache=kv_cache,
|
||||||
slot_indices=attn_metadata.slot_mapping.flatten())
|
slot_indices=attn_metadata.slot_mapping.flatten())
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
output[num_decode_tokens:] = self._forward_prefill(
|
output[num_decode_tokens:] = self._forward_prefill(
|
||||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||||
attn_metadata)
|
attn_metadata)
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
|
if self.running_in_graph:
|
||||||
|
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
||||||
|
decode_k_nope, decode_k_pe,
|
||||||
|
kv_cache, attn_metadata)
|
||||||
|
else:
|
||||||
output[:num_decode_tokens] = self._forward_decode(
|
output[:num_decode_tokens] = self._forward_decode(
|
||||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
|
||||||
|
kv_cache, attn_metadata)
|
||||||
return output_padded
|
return output_padded
|
||||||
@@ -15,18 +15,42 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Iterable, Optional, Union
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||||
from vllm.v1.core.sched.scheduler import Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.core.sched.utils import check_stop
|
||||||
|
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
|
||||||
class AscendScheduler(Scheduler):
|
class AscendScheduler(Scheduler):
|
||||||
"""This Scheduler extends vllm's original v1 scheduler
|
"""This Scheduler extends vllm's original v1 scheduler
|
||||||
with prefill-first scheduling strategy."""
|
with prefill-first scheduling strategy."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_config: KVCacheConfig,
|
||||||
|
structured_output_manager: StructuredOutputManager,
|
||||||
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
|
include_finished_set: bool = False,
|
||||||
|
log_stats: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(vllm_config, kv_cache_config,
|
||||||
|
structured_output_manager, mm_registry,
|
||||||
|
include_finished_set, log_stats)
|
||||||
|
self.scheduled_req_ids: set[str] = set()
|
||||||
|
self.running: list[Request] = []
|
||||||
|
|
||||||
def schedule(self) -> SchedulerOutput:
|
def schedule(self) -> SchedulerOutput:
|
||||||
if self.scheduler_config.chunked_prefill_enabled:
|
if self.scheduler_config.chunked_prefill_enabled:
|
||||||
return super().schedule()
|
return super().schedule()
|
||||||
@@ -317,3 +341,175 @@ class AscendScheduler(Scheduler):
|
|||||||
return request.lora_request.long_lora_max_len
|
return request.lora_request.long_lora_max_len
|
||||||
else:
|
else:
|
||||||
return prompt_limit
|
return prompt_limit
|
||||||
|
|
||||||
|
def finish_requests(
|
||||||
|
self,
|
||||||
|
request_ids: Union[str, Iterable[str]],
|
||||||
|
finished_status: RequestStatus,
|
||||||
|
) -> None:
|
||||||
|
"""Handles the finish signal from outside the scheduler.
|
||||||
|
|
||||||
|
For example, the API server can abort a request when the client
|
||||||
|
disconnects.
|
||||||
|
"""
|
||||||
|
assert RequestStatus.is_finished(finished_status)
|
||||||
|
if isinstance(request_ids, str):
|
||||||
|
request_ids = (request_ids, )
|
||||||
|
else:
|
||||||
|
request_ids = set(request_ids)
|
||||||
|
|
||||||
|
for req_id in request_ids:
|
||||||
|
request = self.requests.get(req_id)
|
||||||
|
if request is None:
|
||||||
|
# Invalid request ID.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if request.status == RequestStatus.RUNNING:
|
||||||
|
self.running.remove(request)
|
||||||
|
self.scheduled_req_ids.discard(request.request_id)
|
||||||
|
else:
|
||||||
|
self.waiting.remove(request)
|
||||||
|
request.status = finished_status
|
||||||
|
self._free_request(request)
|
||||||
|
|
||||||
|
def update_from_output(
|
||||||
|
self,
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
model_runner_output: ModelRunnerOutput,
|
||||||
|
) -> EngineCoreOutputs:
|
||||||
|
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||||
|
spec_token_ids = model_runner_output.spec_token_ids
|
||||||
|
logprobs = model_runner_output.logprobs
|
||||||
|
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
|
|
||||||
|
new_running: list[Request] = []
|
||||||
|
outputs: list[EngineCoreOutput] = []
|
||||||
|
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||||
|
|
||||||
|
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||||
|
# loop can be a performance bottleneck. We should do our best to avoid
|
||||||
|
# expensive operations inside the loop.
|
||||||
|
for request in self.running:
|
||||||
|
req_id = request.request_id
|
||||||
|
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||||
|
if num_tokens_scheduled == 0:
|
||||||
|
# The request was not scheduled in this step.
|
||||||
|
new_running.append(request)
|
||||||
|
continue
|
||||||
|
|
||||||
|
req_index = model_runner_output.req_id_to_index[req_id]
|
||||||
|
generated_token_ids = sampled_token_ids[req_index]
|
||||||
|
|
||||||
|
scheduled_spec_token_ids = (
|
||||||
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||||
|
if scheduled_spec_token_ids:
|
||||||
|
# num_computed_tokens represents the number of tokens
|
||||||
|
# processed in the current step, considering scheduled
|
||||||
|
# tokens and rejections. If some tokens are rejected,
|
||||||
|
# num_computed_tokens is decreased by the number of rejected
|
||||||
|
# tokens, where is given by:
|
||||||
|
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
||||||
|
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
|
||||||
|
len(generated_token_ids))
|
||||||
|
request.num_computed_tokens -= num_tokens_rejected
|
||||||
|
spec_decoding_stats = self.make_spec_decoding_stats(
|
||||||
|
spec_decoding_stats,
|
||||||
|
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||||
|
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||||
|
|
||||||
|
cached_encoder_input_ids = (
|
||||||
|
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||||
|
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
||||||
|
if cached_encoder_input_ids:
|
||||||
|
for input_id in list(cached_encoder_input_ids):
|
||||||
|
mm_positions = request.mm_positions[input_id]
|
||||||
|
start_pos = mm_positions.offset
|
||||||
|
num_tokens = mm_positions.length
|
||||||
|
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||||
|
# The encoder output is already processed and stored
|
||||||
|
# in the decoder's KV cache.
|
||||||
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
|
request, input_id)
|
||||||
|
|
||||||
|
stopped = False
|
||||||
|
new_logprobs = None
|
||||||
|
new_token_ids = generated_token_ids
|
||||||
|
|
||||||
|
# Append generated tokens and check for stop. Note that if
|
||||||
|
# a request is still being prefilled, we expect the model runner
|
||||||
|
# to return empty token ids for the request.
|
||||||
|
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
||||||
|
request.append_output_token_ids(output_token_id)
|
||||||
|
|
||||||
|
# Check for stop and update request state.
|
||||||
|
# This must be called before we make the EngineCoreOutput.
|
||||||
|
stopped = check_stop(request, self.max_model_len)
|
||||||
|
if stopped:
|
||||||
|
self._free_request(request)
|
||||||
|
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract sample logprobs if needed.
|
||||||
|
if request.sampling_params.logprobs is not None and logprobs:
|
||||||
|
# NOTE: once we support N tokens per step (spec decode),
|
||||||
|
# the outer lists can be of length > 1.
|
||||||
|
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||||
|
|
||||||
|
if new_token_ids and request.use_structured_output:
|
||||||
|
# NOTE: structured_output_request
|
||||||
|
# should not be None if use_structured_output, we have
|
||||||
|
# check above, so safe to ignore type warning
|
||||||
|
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||||
|
req_id, new_token_ids)
|
||||||
|
|
||||||
|
# Add newly generated spec token ids to the request.
|
||||||
|
if spec_token_ids is not None:
|
||||||
|
if request.use_structured_output:
|
||||||
|
metadata = request.structured_output_request
|
||||||
|
assert metadata is not None and metadata.grammar is not None
|
||||||
|
# Needs to happen after new_token_ids are accepted.
|
||||||
|
request.spec_token_ids = metadata.grammar.validate_tokens(
|
||||||
|
spec_token_ids[req_index])
|
||||||
|
else:
|
||||||
|
request.spec_token_ids = spec_token_ids[req_index]
|
||||||
|
|
||||||
|
# Get prompt logprobs for this request.
|
||||||
|
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||||
|
if new_token_ids:
|
||||||
|
# Add EngineCoreOutput for this Request.
|
||||||
|
outputs.append(
|
||||||
|
EngineCoreOutput(
|
||||||
|
request_id=req_id,
|
||||||
|
new_token_ids=new_token_ids,
|
||||||
|
finish_reason=request.get_finished_reason(),
|
||||||
|
new_logprobs=new_logprobs,
|
||||||
|
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||||
|
stop_reason=request.stop_reason,
|
||||||
|
events=request.take_events()))
|
||||||
|
else:
|
||||||
|
# Invariant: EngineCore returns no partial prefill outputs.
|
||||||
|
assert not prompt_logprobs_tensors
|
||||||
|
|
||||||
|
self.scheduled_req_ids.remove(req_id)
|
||||||
|
if not stopped:
|
||||||
|
new_running.append(request)
|
||||||
|
|
||||||
|
# Return the cached request data to the queue so they can be reused.
|
||||||
|
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||||
|
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
|
||||||
|
# to _cached_reqs_data will cause a memory leak.
|
||||||
|
if req_data.req_id not in self.finished_req_ids:
|
||||||
|
self._cached_reqs_data[req_data.req_id].append(req_data)
|
||||||
|
|
||||||
|
self.running = new_running
|
||||||
|
engine_core_outputs = EngineCoreOutputs(
|
||||||
|
outputs=outputs,
|
||||||
|
scheduler_stats=self.make_stats(spec_decoding_stats),
|
||||||
|
)
|
||||||
|
if self.include_finished_set:
|
||||||
|
#TODO currently sending duplicates here, improve this
|
||||||
|
engine_core_outputs.finished_requests = (
|
||||||
|
scheduler_output.finished_req_ids | self.finished_req_ids)
|
||||||
|
|
||||||
|
return engine_core_outputs
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
import vllm.envs as envs
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
@@ -396,10 +397,22 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
else:
|
else:
|
||||||
hidden_states_or_q_c = hidden_states
|
hidden_states_or_q_c = hidden_states
|
||||||
if self.enable_graph_mode:
|
if self.enable_graph_mode:
|
||||||
return self.mla_attn.impl.forward(self.mla_attn,
|
forward_kwargs = {}
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
output_shape = hidden_states.shape
|
||||||
|
output = torch.empty(output_shape,
|
||||||
|
dtype=hidden_states_or_q_c.dtype,
|
||||||
|
device=hidden_states_or_q_c.device)
|
||||||
|
forward_kwargs['output'] = output
|
||||||
|
|
||||||
|
output = self.mla_attn.impl.forward(self.mla_attn,
|
||||||
hidden_states_or_q_c,
|
hidden_states_or_q_c,
|
||||||
hidden_states, None, kv_cache,
|
hidden_states, None, kv_cache,
|
||||||
attn_metadata)
|
attn_metadata,
|
||||||
|
**forward_kwargs)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
output = output.view(-1, output_shape[-1])
|
||||||
|
return output
|
||||||
else:
|
else:
|
||||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|||||||
@@ -153,9 +153,9 @@ class NPUPlatform(Platform):
|
|||||||
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
|
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
|
||||||
)
|
)
|
||||||
vllm_config.additional_config["enable_graph_mode"] = False
|
vllm_config.additional_config["enable_graph_mode"] = False
|
||||||
if enable_graph_mode and envs.VLLM_USE_V1:
|
if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"NPU graph mode is still experimental and not supported for V1 currently, "
|
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
|
||||||
"it has been disabled automatically.")
|
"it has been disabled automatically.")
|
||||||
vllm_config.additional_config["enable_graph_mode"] = False
|
vllm_config.additional_config["enable_graph_mode"] = False
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphCaptureContext:
|
class GraphCaptureContext:
|
||||||
@@ -117,6 +119,12 @@ class NPUModelRunner:
|
|||||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
|
|
||||||
|
self.graph_block_tables = np.zeros(
|
||||||
|
(self.vllm_config.scheduler_config.max_num_seqs,
|
||||||
|
(self.model_config.max_model_len + self.block_size - 1) //
|
||||||
|
self.block_size),
|
||||||
|
dtype=np.int32)
|
||||||
|
|
||||||
# Model-related.
|
# Model-related.
|
||||||
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
||||||
vllm_config.parallel_config, LayerBlockType.attention)
|
vllm_config.parallel_config, LayerBlockType.attention)
|
||||||
@@ -307,6 +315,15 @@ class NPUModelRunner:
|
|||||||
self.attn_mask_len, self.dtype)
|
self.attn_mask_len, self.dtype)
|
||||||
|
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
self.enable_torchair_graph_mode = False
|
||||||
|
self.use_cached_npu_graph = False
|
||||||
|
additional_config = vllm_config.additional_config
|
||||||
|
if additional_config:
|
||||||
|
self.enable_torchair_graph_mode = additional_config.get(
|
||||||
|
"enable_graph_mode",
|
||||||
|
False) and self.vllm_config.model_config.use_mla
|
||||||
|
self.use_cached_npu_graph = additional_config.get(
|
||||||
|
"use_cached_npu_graph", False)
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
@@ -563,11 +580,19 @@ class NPUModelRunner:
|
|||||||
self.attn_mask = attn_mask
|
self.attn_mask = attn_mask
|
||||||
self.attn_state = attn_state # type: ignore
|
self.attn_state = attn_state # type: ignore
|
||||||
|
|
||||||
|
extra_builder_kwargs = {}
|
||||||
|
|
||||||
|
# Add graph_pad_size here
|
||||||
|
if self.enable_torchair_graph_mode:
|
||||||
|
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
|
||||||
|
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||||
|
|
||||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
common_prefix_len=None,
|
common_prefix_len=None,
|
||||||
|
**extra_builder_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare input_ids
|
# Prepare input_ids
|
||||||
@@ -582,14 +607,44 @@ class NPUModelRunner:
|
|||||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
||||||
|
|
||||||
|
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
padding = torch.zeros(graph_pad_size,
|
||||||
|
dtype=input_ids.dtype,
|
||||||
|
device=input_ids.device)
|
||||||
|
input_ids = torch.cat([input_ids, padding])
|
||||||
|
positions = torch.cat([positions, padding])
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
|
model_kwargs = {}
|
||||||
|
if self.enable_torchair_graph_mode:
|
||||||
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
|
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
torch._dynamo.mark_static(input_ids)
|
||||||
|
torch._dynamo.mark_static(positions)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||||
|
for kv in self.kv_caches:
|
||||||
|
if isinstance(kv, tuple):
|
||||||
|
torch._dynamo.mark_static(kv[0])
|
||||||
|
torch._dynamo.mark_static(kv[1])
|
||||||
|
hidden_states = self.compile_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=None,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states[sample_indices]
|
return hidden_states[sample_indices]
|
||||||
@@ -879,6 +934,31 @@ class NPUModelRunner:
|
|||||||
logger.info("Loading model weights took %.4f GB",
|
logger.info("Loading model weights took %.4f GB",
|
||||||
m.consumed_memory / float(2**30))
|
m.consumed_memory / float(2**30))
|
||||||
|
|
||||||
|
# adapter torch compile with npu_backend
|
||||||
|
if self.enable_torchair_graph_mode:
|
||||||
|
import torchair # type: ignore
|
||||||
|
from torchair import patch_for_hcom # type: ignore
|
||||||
|
|
||||||
|
patch_for_hcom()
|
||||||
|
config = torchair.CompilerConfig()
|
||||||
|
config.experimental_config.frozen_parameter = True
|
||||||
|
config.experimental_config.tiling_schedule_optimize = True
|
||||||
|
torch.npu.set_compile_mode(jit_compile=False)
|
||||||
|
if not self.use_cached_npu_graph:
|
||||||
|
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||||
|
self.compile_model = torch.compile(
|
||||||
|
self.model,
|
||||||
|
dynamic=True,
|
||||||
|
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
backend=npu_backend)
|
||||||
|
else:
|
||||||
|
self.compile_model = torchair.inference.cache_compile(
|
||||||
|
self.model.forward,
|
||||||
|
dynamic=True,
|
||||||
|
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
config=config,
|
||||||
|
ge_cache=False)
|
||||||
|
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize KV cache based on `kv_cache_config`.
|
Initialize KV cache based on `kv_cache_config`.
|
||||||
@@ -909,6 +989,25 @@ class NPUModelRunner:
|
|||||||
num_blocks, kv_cache_spec.block_size,
|
num_blocks, kv_cache_spec.block_size,
|
||||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
|
if self.enable_torchair_graph_mode:
|
||||||
|
layer_kv_cache_nope = torch.zeros(
|
||||||
|
kv_cache_shape[:-1] +
|
||||||
|
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||||
|
dtype=self.dtype,
|
||||||
|
pin_memory=True,
|
||||||
|
device=self.device)
|
||||||
|
layer_kv_cache_pe = torch.zeros(
|
||||||
|
kv_cache_shape[:-1] +
|
||||||
|
(self.model_config.hf_text_config.qk_rope_head_dim,
|
||||||
|
),
|
||||||
|
dtype=self.dtype,
|
||||||
|
pin_memory=True,
|
||||||
|
device=self.device)
|
||||||
|
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||||
|
layer_kv_cache_pe)
|
||||||
|
torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
|
||||||
|
torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
|
||||||
|
else:
|
||||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user