feat: support torchair graph mode in v1 engine (#789)
### What this PR does / why we need it? support torchair graph mode with v1 engine --------- Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
@@ -51,6 +54,7 @@ class AscendMLAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
seq_lens: list[int]
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
@@ -66,6 +70,7 @@ class AscendMLADecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -195,11 +200,38 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
||||
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
else:
|
||||
graph_block_tables = self.runner.graph_block_tables.to(
|
||||
device=block_tables.device, dtype=block_tables.dtype)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[:num_seqs, :
|
||||
num_blocks] = block_tables[:num_seqs, :
|
||||
num_blocks]
|
||||
else:
|
||||
graph_block_tables[:num_seqs, :
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables
|
||||
|
||||
def build(self,
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
|
||||
common_prefix_len: Optional[int] = None,
|
||||
graph_pad_size: int = -1) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
@@ -230,6 +262,7 @@ class AscendMLAMetadataBuilder:
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
seq_lens=seq_lens,
|
||||
context_lens=seq_lens[tokens_start:],
|
||||
input_positions=input_positions[tokens_start:],
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
@@ -238,12 +271,46 @@ class AscendMLAMetadataBuilder:
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
if self._num_decodes > 0:
|
||||
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
||||
seq_lens = seq_lens[:self._num_decode_tokens]
|
||||
input_positions = input_positions[:self._num_decode_tokens]
|
||||
block_table = block_table[:self._num_decode_tokens, ...]
|
||||
if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_seqs = len(seq_lens)
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * graph_pad_size
|
||||
else:
|
||||
padded_seq_lens = seq_lens.tolist()
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
padding = torch.full((graph_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(graph_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs, block_table)
|
||||
padding_0 = torch.zeros(graph_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat([input_positions, padding_0])
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions[:self._num_decode_tokens],
|
||||
block_table=block_table[:self._num_decode_tokens, ...],
|
||||
seq_lens=seq_lens[:self._num_decode_tokens],
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
@@ -323,6 +390,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
@@ -332,6 +401,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# functools.partial(flash_attn_varlen_func,
|
||||
# fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
self.enable_graph_mode = False
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
@@ -485,15 +560,55 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
return self.o_proj(attn_output)[0]
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
|
||||
B = hidden_states.shape[0]
|
||||
N = self.num_kv_heads
|
||||
S = 1
|
||||
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA",
|
||||
)
|
||||
return k_pe, k_nope
|
||||
|
||||
def rope_single(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
B, N, D = x.shape
|
||||
S = 1
|
||||
x = x.view(B, N, S, D)
|
||||
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
||||
return x.view(B, N, D)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
k_nope: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
|
||||
@@ -503,101 +618,181 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
if self.running_in_graph:
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=decode_meta.block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||
)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
return self._v_up_proj_and_o_proj(attn_output)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
if k_pe is None and not self.running_in_graph:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
if not self.running_in_graph:
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
||||
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
||||
if not self.running_in_graph:
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
else:
|
||||
decode_hs_or_q_c = hidden_states_or_q_c
|
||||
if has_decode:
|
||||
decode_k_nope = None
|
||||
assert attn_metadata.decode is not None
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.decode.input_positions,
|
||||
decode_q_pe.contiguous(),
|
||||
decode_k_pe,
|
||||
max_seq_len=attn_metadata.decode.max_seq_lens)
|
||||
|
||||
if self.running_in_graph:
|
||||
seq_len = self.rotary_emb.max_position_embeddings
|
||||
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
||||
dtype=decode_q_pe.dtype)
|
||||
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
||||
dtype=decode_q_pe.dtype)
|
||||
cos = cos[attn_metadata.decode.input_positions]
|
||||
sin = sin[attn_metadata.decode.input_positions]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_k_pe, decode_k_nope = self.exec_kv(
|
||||
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
||||
attn_metadata.slot_mapping)
|
||||
else:
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.decode.input_positions,
|
||||
decode_q_pe.contiguous(),
|
||||
decode_k_pe,
|
||||
max_seq_len=attn_metadata.decode.max_seq_lens)
|
||||
if has_prefill:
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions,
|
||||
prefill_q_pe.contiguous(),
|
||||
prefill_k_pe,
|
||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||
if self.enable_graph_mode:
|
||||
num_tokens = prefill_hs_or_q_c.shape[0]
|
||||
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
||||
-1)
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
# NOTE: When scaling not specified
|
||||
ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
|
||||
prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
|
||||
prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
|
||||
prefill_q_pe, prefill_k_pe = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions, prefill_q_pe,
|
||||
prefill_k_pe)
|
||||
prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
|
||||
prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
|
||||
else:
|
||||
prefill_q_pe, prefill_k_pe = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions, prefill_q_pe,
|
||||
prefill_k_pe)
|
||||
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
||||
else:
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions,
|
||||
prefill_q_pe.contiguous(),
|
||||
prefill_k_pe,
|
||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
||||
if self.enable_graph_mode:
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
slots = attn_metadata.slot_mapping
|
||||
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
||||
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||
num_tokens, self.num_kv_heads, -1),
|
||||
value=prefill_k_pe,
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slots)
|
||||
elif kv_cache.numel() > 0:
|
||||
key = torch.cat([
|
||||
k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
|
||||
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
||||
k_pe
|
||||
],
|
||||
dim=2)
|
||||
torch_npu._npu_reshape_and_cache_siso(
|
||||
key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output_padded
|
||||
if self.running_in_graph:
|
||||
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
||||
decode_k_nope, decode_k_pe,
|
||||
kv_cache, attn_metadata)
|
||||
else:
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
|
||||
kv_cache, attn_metadata)
|
||||
return output_padded
|
||||
@@ -15,18 +15,42 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from collections import deque
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
|
||||
class AscendScheduler(Scheduler):
|
||||
"""This Scheduler extends vllm's original v1 scheduler
|
||||
with prefill-first scheduling strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
super().__init__(vllm_config, kv_cache_config,
|
||||
structured_output_manager, mm_registry,
|
||||
include_finished_set, log_stats)
|
||||
self.scheduled_req_ids: set[str] = set()
|
||||
self.running: list[Request] = []
|
||||
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
return super().schedule()
|
||||
@@ -317,3 +341,175 @@ class AscendScheduler(Scheduler):
|
||||
return request.lora_request.long_lora_max_len
|
||||
else:
|
||||
return prompt_limit
|
||||
|
||||
def finish_requests(
|
||||
self,
|
||||
request_ids: Union[str, Iterable[str]],
|
||||
finished_status: RequestStatus,
|
||||
) -> None:
|
||||
"""Handles the finish signal from outside the scheduler.
|
||||
|
||||
For example, the API server can abort a request when the client
|
||||
disconnects.
|
||||
"""
|
||||
assert RequestStatus.is_finished(finished_status)
|
||||
if isinstance(request_ids, str):
|
||||
request_ids = (request_ids, )
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
self.running.remove(request)
|
||||
self.scheduled_req_ids.discard(request.request_id)
|
||||
else:
|
||||
self.waiting.remove(request)
|
||||
request.status = finished_status
|
||||
self._free_request(request)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> EngineCoreOutputs:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
spec_token_ids = model_runner_output.spec_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
new_running: list[Request] = []
|
||||
outputs: list[EngineCoreOutput] = []
|
||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||
|
||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||
# loop can be a performance bottleneck. We should do our best to avoid
|
||||
# expensive operations inside the loop.
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||
if num_tokens_scheduled == 0:
|
||||
# The request was not scheduled in this step.
|
||||
new_running.append(request)
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[req_index]
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
if scheduled_spec_token_ids:
|
||||
# num_computed_tokens represents the number of tokens
|
||||
# processed in the current step, considering scheduled
|
||||
# tokens and rejections. If some tokens are rejected,
|
||||
# num_computed_tokens is decreased by the number of rejected
|
||||
# tokens, where is given by:
|
||||
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
||||
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
|
||||
len(generated_token_ids))
|
||||
request.num_computed_tokens -= num_tokens_rejected
|
||||
spec_decoding_stats = self.make_spec_decoding_stats(
|
||||
spec_decoding_stats,
|
||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||
|
||||
cached_encoder_input_ids = (
|
||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
||||
if cached_encoder_input_ids:
|
||||
for input_id in list(cached_encoder_input_ids):
|
||||
mm_positions = request.mm_positions[input_id]
|
||||
start_pos = mm_positions.offset
|
||||
num_tokens = mm_positions.length
|
||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
self.encoder_cache_manager.free_encoder_input(
|
||||
request, input_id)
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = generated_token_ids
|
||||
|
||||
# Append generated tokens and check for stop. Note that if
|
||||
# a request is still being prefilled, we expect the model runner
|
||||
# to return empty token ids for the request.
|
||||
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
||||
request.append_output_token_ids(output_token_id)
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
stopped = check_stop(request, self.max_model_len)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||
break
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params.logprobs is not None and logprobs:
|
||||
# NOTE: once we support N tokens per step (spec decode),
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and request.use_structured_output:
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# check above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids)
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if spec_token_ids is not None:
|
||||
if request.use_structured_output:
|
||||
metadata = request.structured_output_request
|
||||
assert metadata is not None and metadata.grammar is not None
|
||||
# Needs to happen after new_token_ids are accepted.
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens(
|
||||
spec_token_ids[req_index])
|
||||
else:
|
||||
request.spec_token_ids = spec_token_ids[req_index]
|
||||
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
if new_token_ids:
|
||||
# Add EngineCoreOutput for this Request.
|
||||
outputs.append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids,
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
stop_reason=request.stop_reason,
|
||||
events=request.take_events()))
|
||||
else:
|
||||
# Invariant: EngineCore returns no partial prefill outputs.
|
||||
assert not prompt_logprobs_tensors
|
||||
|
||||
self.scheduled_req_ids.remove(req_id)
|
||||
if not stopped:
|
||||
new_running.append(request)
|
||||
|
||||
# Return the cached request data to the queue so they can be reused.
|
||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
|
||||
# to _cached_reqs_data will cause a memory leak.
|
||||
if req_data.req_id not in self.finished_req_ids:
|
||||
self._cached_reqs_data[req_data.req_id].append(req_data)
|
||||
|
||||
self.running = new_running
|
||||
engine_core_outputs = EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
scheduler_stats=self.make_stats(spec_decoding_stats),
|
||||
)
|
||||
if self.include_finished_set:
|
||||
#TODO currently sending duplicates here, improve this
|
||||
engine_core_outputs.finished_requests = (
|
||||
scheduler_output.finished_req_ids | self.finished_req_ids)
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
@@ -31,6 +31,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
@@ -396,10 +397,22 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
if self.enable_graph_mode:
|
||||
return self.mla_attn.impl.forward(self.mla_attn,
|
||||
hidden_states_or_q_c,
|
||||
hidden_states, None, kv_cache,
|
||||
attn_metadata)
|
||||
forward_kwargs = {}
|
||||
if envs.VLLM_USE_V1:
|
||||
output_shape = hidden_states.shape
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states_or_q_c.dtype,
|
||||
device=hidden_states_or_q_c.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
output = self.mla_attn.impl.forward(self.mla_attn,
|
||||
hidden_states_or_q_c,
|
||||
hidden_states, None, kv_cache,
|
||||
attn_metadata,
|
||||
**forward_kwargs)
|
||||
if envs.VLLM_USE_V1:
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
else:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
@@ -653,4 +666,4 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
|
||||
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -153,9 +153,9 @@ class NPUPlatform(Platform):
|
||||
"enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
|
||||
)
|
||||
vllm_config.additional_config["enable_graph_mode"] = False
|
||||
if enable_graph_mode and envs.VLLM_USE_V1:
|
||||
if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
|
||||
logger.warning(
|
||||
"NPU graph mode is still experimental and not supported for V1 currently, "
|
||||
"NPU graph mode is still experimental and not supported for V1 without mla currently, "
|
||||
"it has been disabled automatically.")
|
||||
vllm_config.additional_config["enable_graph_mode"] = False
|
||||
|
||||
|
||||
@@ -63,6 +63,8 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphCaptureContext:
|
||||
@@ -117,6 +119,12 @@ class NPUModelRunner:
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.vllm_config.scheduler_config.max_num_seqs,
|
||||
(self.model_config.max_model_len + self.block_size - 1) //
|
||||
self.block_size),
|
||||
dtype=np.int32)
|
||||
|
||||
# Model-related.
|
||||
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
||||
vllm_config.parallel_config, LayerBlockType.attention)
|
||||
@@ -307,6 +315,15 @@ class NPUModelRunner:
|
||||
self.attn_mask_len, self.dtype)
|
||||
|
||||
self.sampler = Sampler()
|
||||
self.enable_torchair_graph_mode = False
|
||||
self.use_cached_npu_graph = False
|
||||
additional_config = vllm_config.additional_config
|
||||
if additional_config:
|
||||
self.enable_torchair_graph_mode = additional_config.get(
|
||||
"enable_graph_mode",
|
||||
False) and self.vllm_config.model_config.use_mla
|
||||
self.use_cached_npu_graph = additional_config.get(
|
||||
"use_cached_npu_graph", False)
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
@@ -563,11 +580,19 @@ class NPUModelRunner:
|
||||
self.attn_mask = attn_mask
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
extra_builder_kwargs = {}
|
||||
|
||||
# Add graph_pad_size here
|
||||
if self.enable_torchair_graph_mode:
|
||||
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
|
||||
# Prepare input_ids
|
||||
@@ -582,15 +607,45 @@ class NPUModelRunner:
|
||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
||||
|
||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
padding = torch.zeros(graph_pad_size,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device)
|
||||
input_ids = torch.cat([input_ids, padding])
|
||||
positions = torch.cat([positions, padding])
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
)
|
||||
model_kwargs = {}
|
||||
if self.enable_torchair_graph_mode:
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
for kv in self.kv_caches:
|
||||
if isinstance(kv, tuple):
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
hidden_states = self.compile_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return hidden_states[sample_indices]
|
||||
|
||||
@@ -879,6 +934,31 @@ class NPUModelRunner:
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.enable_torchair_graph_mode:
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.compile_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=npu_backend)
|
||||
else:
|
||||
self.compile_model = torchair.inference.cache_compile(
|
||||
self.model.forward,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
@@ -909,10 +989,29 @@ class NPUModelRunner:
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
|
||||
if self.enable_torchair_graph_mode:
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
|
||||
else:
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name], 2)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
|
||||
Reference in New Issue
Block a user