feat: support torchair graph mode in v1 engine (#789)

### What this PR does / why we need it?
Support torchair graph mode with the v1 engine. The mode is enabled through `additional_config={"enable_graph_mode": True}`; on the v1 engine it currently requires MLA, and it is disabled automatically when `VLLM_MLA_DISABLE` is set.
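
For reviewers who want to try it, a minimal usage sketch. The config keys `enable_graph_mode` and `use_cached_npu_graph` are the ones this PR reads; the `LLM(...)` entrypoint, its `additional_config` kwarg, and the model name are illustrative assumptions about the caller side, not part of this change:

# Hedged sketch: config keys come from this PR; the entrypoint and model
# name are assumptions for illustration only.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # graph mode on v1 requires an MLA model
    additional_config={
        "enable_graph_mode": True,      # read by NPUModelRunner and AscendMLAImpl
        "use_cached_npu_graph": False,  # True: reuse torchair's cached compiled graph
    },
)
print(llm.generate(["Hello"])[0].outputs[0].text)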

---------

Signed-off-by: boying <897013703@qq.com>
Author: NeverRaR
Date: 2025-05-12 19:14:07 +08:00 (committed by GitHub)
Parent: 4a2505f81f
Commit: efabd722eb
5 changed files with 585 additions and 82 deletions

File 1 of 5: Ascend MLA attention backend (AscendMLAMetadataBuilder / AscendMLAImpl)

@@ -1,11 +1,14 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar

+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -51,6 +54,7 @@ class AscendMLAPrefillMetadata:
     """ Prefill Specific Metadata for Ascend"""
     attn_mask: torch.Tensor
     query_lens: list[int]
+    seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
     block_table: torch.Tensor
@@ -66,6 +70,7 @@ class AscendMLADecodeMetadata:
     block_table: torch.Tensor
     seq_lens: torch.Tensor
     max_seq_lens: int
+    seq_lens_list: list[int]


 @dataclass
@@ -195,11 +200,38 @@ class AscendMLAMetadataBuilder:

         return modified_batch

+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :
+                               num_blocks] = block_tables[:num_seqs, :
+                                                          num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :
+                               max_blocks] = block_tables[:num_seqs, :
+                                                          max_blocks]
+
+        return graph_block_tables
+
     def build(self,
               num_reqs: int,
               num_actual_tokens: int,
               max_query_len: int,
-              common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
+              common_prefix_len: Optional[int] = None,
+              graph_pad_size: int = -1) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -230,6 +262,7 @@ class AscendMLAMetadataBuilder:
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
                 query_lens=query_lens[tokens_start:],
+                seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
@@ -238,12 +271,46 @@ class AscendMLAMetadataBuilder:
             )

         decode_metadata = None
+        use_torchair_graph = graph_pad_size != -1
         if self._num_decodes > 0:
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
+            seq_lens = seq_lens[:self._num_decode_tokens]
+            input_positions = input_positions[:self._num_decode_tokens]
+            block_table = block_table[:self._num_decode_tokens, ...]
+            if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + [pad_value
+                                                           ] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs, block_table)
+                padding_0 = torch.zeros(graph_pad_size,
+                                        dtype=input_positions.dtype,
+                                        device=input_positions.device)
+                input_positions = torch.cat([input_positions, padding_0])
             decode_metadata = AscendMLADecodeMetadata(
-                input_positions=input_positions[:self._num_decode_tokens],
-                block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens],
+                input_positions=input_positions,
+                block_table=block_table,
+                seq_lens=seq_lens,
+                seq_lens_list=seq_lens.tolist(),
                 max_seq_lens=max_seq_lens)

         return self.metadata_cls(  # type: ignore
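
To make the padding above concrete, here is a small worked example of the decode-only branch; all numbers are illustrative, and `PAD_SLOT_ID` is assumed to be vLLM's usual -1 sentinel:

import numpy as np
import torch

PAD_SLOT_ID = -1  # assumption: vLLM's padding-slot sentinel

# Two real decode requests padded to a graph batch of four (graph_pad_size=2).
seq_lens = torch.tensor([7, 13])
slot_mapping = torch.tensor([70, 131])
graph_pad_size = 2

padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size  # pad_value is 1, not 0
seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)
slot_mapping = torch.cat([slot_mapping, padding])

print(seq_lens.tolist())      # [7, 13, 1, 1]
print(slot_mapping.tolist())  # [70, 131, -1, -1]

Padding seq_lens with 1 rather than 0 keeps every per-sequence length valid for the attention kernel, while PAD_SLOT_ID keeps the dummy tokens' KV writes out of real cache slots.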
@@ -323,6 +390,8 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
+        self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
+        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)

         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -332,6 +401,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         #   functools.partial(flash_attn_varlen_func,
         #                     fa_version=self.vllm_flash_attn_version)

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -485,15 +560,55 @@ class AscendMLAImpl(MLAAttentionImpl):
             [num_tokens, self.num_heads * self.v_head_dim])
         return self.o_proj(attn_output)[0]

+    def exec_kv(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode="PA",
+        )
+        return k_pe, k_nope
+
+    def rope_single(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> torch.Tensor:
+        B, N, D = x.shape
+        S = 1
+        x = x.view(B, N, S, D)
+        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        return x.view(B, N, D)
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -503,6 +618,36 @@ class AscendMLAImpl(MLAAttentionImpl):
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
+        if self.running_in_graph:
+            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
+            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # shape of knope/k_pe for npu graph mode should be:
+            # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+            block_size = kv_c_and_k_pe_cache[0].shape[1]
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                k_nope,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="BNSD",
+                atten_mask=attn_metadata.attn_mask,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=decode_meta.block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=decode_meta.seq_lens_list,
+            )
+        else:
             torch_npu._npu_paged_attention_mla(
                 query=q,
                 key_cache=kv_c_and_k_pe_cache,
@@ -519,85 +664,135 @@ class AscendMLAImpl(MLAAttentionImpl):
         self,
         layer: AttentionLayer,
         hidden_states_or_q_c: torch.Tensor,  # query in unified attn
-        k_c_normed: torch.Tensor,  # key in unified attn
+        hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
             return output
+        self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
         num_actual_toks = attn_metadata.num_actual_tokens
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-        k_c_normed = k_c_normed[:num_actual_toks, ...]
-        k_pe = k_pe[:num_actual_toks, ...]
-        # Restore head dim (for rotary embedding)
-        k_pe = k_pe.unsqueeze(1)
+        if k_pe is None and not self.running_in_graph:
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        else:
+            kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
+        if not self.running_in_graph:
+            # Inputs and outputs may be padded for CUDA graphs
+            output_padded = output
+            output = output[:num_actual_toks, ...]
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.running_in_graph:
+            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-        decode_k_pe = k_pe[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
             prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+        else:
+            decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
+            decode_k_nope = None
             assert attn_metadata.decode is not None
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            if self.running_in_graph:
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=decode_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=decode_q_pe.dtype)
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                decode_k_pe, decode_k_nope = self.exec_kv(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+            else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
                     decode_q_pe.contiguous(),
                     decode_k_pe,
                     max_seq_len=attn_metadata.decode.max_seq_lens)
         if has_prefill:
             assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+            if self.enable_graph_mode:
+                num_tokens = prefill_hs_or_q_c.shape[0]
+                prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
+                                                 -1)
+                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
+                    # NOTE: When scaling not specified
+                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
+                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
+                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
+                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                        attn_metadata.prefill.input_positions, prefill_q_pe,
+                        prefill_k_pe)
+                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
+                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
+                else:
+                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                        attn_metadata.prefill.input_positions, prefill_q_pe,
+                        prefill_k_pe)
+                prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
+            else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                     attn_metadata.prefill.input_positions,
                     prefill_q_pe.contiguous(),
                     prefill_k_pe,
                     max_seq_len=attn_metadata.prefill.max_seq_lens)
+        if self.enable_graph_mode:
-        if kv_cache.numel() > 0:
+            if len(kv_cache) > 0 and kv_cache[0].numel(
+            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                slots = attn_metadata.slot_mapping
+                # NOTE: Separate the kv cache in advance to avoid OOM or other issues
+                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+                    num_tokens, self.num_kv_heads, -1),
+                                                 value=prefill_k_pe,
+                                                 key_cache=kv_cache[0],
+                                                 value_cache=kv_cache[1],
+                                                 slot_indices=slots)
+        elif kv_cache.numel() > 0:
             key = torch.cat([
-                k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe
+                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
+                k_pe
             ],
                             dim=2)
             torch_npu._npu_reshape_and_cache_siso(
                 key=key,
                 key_cache=kv_cache,
                 slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                 attn_metadata)
         if has_decode:
+            if self.running_in_graph:
+                return self._forward_decode(decode_ql_nope, decode_q_pe,
+                                            decode_k_nope, decode_k_pe,
+                                            kv_cache, attn_metadata)
+            else:
                 output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                    kv_cache, attn_metadata)
         return output_padded
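
As a shape walkthrough of the graph-mode decode path above (sizes are illustrative; the BNSD layout and per-block cache views mirror `_forward_decode`):

import torch

# Illustrative DeepSeek-V2-style sizes; only the shapes matter here.
num_tokens, num_heads, num_kv_heads = 4, 16, 1
kv_lora_rank, rope_dim = 512, 64
num_blocks, block_size = 8, 128

# Queries become BNSD with S=1 (one new token per sequence).
q_nope = torch.zeros(num_tokens, num_heads, kv_lora_rank).view(
    num_tokens, num_heads, 1, -1)  # [4, 16, 1, 512]
q_pe = torch.zeros(num_tokens, num_heads, rope_dim).view(
    num_tokens, num_heads, 1, -1)  # [4, 16, 1, 64]

# The paged cache halves are viewed per block for the fused kernel.
k_nope = torch.zeros(num_blocks, num_kv_heads, block_size, kv_lora_rank)
k_pe = torch.zeros(num_blocks, num_kv_heads, block_size, rope_dim)

# npu_fused_infer_attention_score then gathers blocks through
# decode_meta.block_table and trims each sequence using
# actual_seq_lengths_kv=decode_meta.seq_lens_list (hence the new field).
print(q_nope.shape, k_nope.shape, k_pe.shape)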

File 2 of 5: AscendScheduler

@@ -15,18 +15,42 @@
 # This file is a part of the vllm-ascend project.
 #
 from collections import deque
+from typing import Iterable, Optional, Union

+from vllm.config import VllmConfig
 from vllm.logger import logger
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.core.sched.utils import check_stop
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm.v1.structured_output import StructuredOutputManager


 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
     with prefill-first scheduling strategy."""

+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        super().__init__(vllm_config, kv_cache_config,
+                         structured_output_manager, mm_registry,
+                         include_finished_set, log_stats)
+        self.scheduled_req_ids: set[str] = set()
+        self.running: list[Request] = []
+
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -317,3 +341,175 @@ class AscendScheduler(Scheduler):
             return request.lora_request.long_lora_max_len
         else:
             return prompt_limit
+
+    def finish_requests(
+        self,
+        request_ids: Union[str, Iterable[str]],
+        finished_status: RequestStatus,
+    ) -> None:
+        """Handles the finish signal from outside the scheduler.
+
+        For example, the API server can abort a request when the client
+        disconnects.
+        """
+        assert RequestStatus.is_finished(finished_status)
+        if isinstance(request_ids, str):
+            request_ids = (request_ids, )
+        else:
+            request_ids = set(request_ids)
+
+        for req_id in request_ids:
+            request = self.requests.get(req_id)
+            if request is None:
+                # Invalid request ID.
+                continue
+
+            if request.status == RequestStatus.RUNNING:
+                self.running.remove(request)
+                self.scheduled_req_ids.discard(request.request_id)
+            else:
+                self.waiting.remove(request)
+            request.status = finished_status
+            self._free_request(request)
+
+    def update_from_output(
+        self,
+        scheduler_output: SchedulerOutput,
+        model_runner_output: ModelRunnerOutput,
+    ) -> EngineCoreOutputs:
+        sampled_token_ids = model_runner_output.sampled_token_ids
+        spec_token_ids = model_runner_output.spec_token_ids
+        logprobs = model_runner_output.logprobs
+        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+
+        new_running: list[Request] = []
+        outputs: list[EngineCoreOutput] = []
+        spec_decoding_stats: Optional[SpecDecodingStats] = None
+
+        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
+        # loop can be a performance bottleneck. We should do our best to avoid
+        # expensive operations inside the loop.
+        for request in self.running:
+            req_id = request.request_id
+            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+            if num_tokens_scheduled == 0:
+                # The request was not scheduled in this step.
+                new_running.append(request)
+                continue
+
+            req_index = model_runner_output.req_id_to_index[req_id]
+            generated_token_ids = sampled_token_ids[req_index]
+
+            scheduled_spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+            if scheduled_spec_token_ids:
+                # num_computed_tokens represents the number of tokens
+                # processed in the current step, considering scheduled
+                # tokens and rejections. If some tokens are rejected,
+                # num_computed_tokens is decreased by the number of rejected
+                # tokens, which is given by:
+                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
+                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
+                                       len(generated_token_ids))
+                request.num_computed_tokens -= num_tokens_rejected
+                spec_decoding_stats = self.make_spec_decoding_stats(
+                    spec_decoding_stats,
+                    num_draft_tokens=len(scheduled_spec_token_ids),
+                    num_accepted_tokens=len(generated_token_ids) - 1)
+
+            cached_encoder_input_ids = (
+                self.encoder_cache_manager.get_cached_input_ids(request))
+            # OPTIMIZATION: Avoid list(set) if the set is empty.
+            if cached_encoder_input_ids:
+                for input_id in list(cached_encoder_input_ids):
+                    mm_positions = request.mm_positions[input_id]
+                    start_pos = mm_positions.offset
+                    num_tokens = mm_positions.length
+                    if start_pos + num_tokens <= request.num_computed_tokens:
+                        # The encoder output is already processed and stored
+                        # in the decoder's KV cache.
+                        self.encoder_cache_manager.free_encoder_input(
+                            request, input_id)
+
+            stopped = False
+            new_logprobs = None
+            new_token_ids = generated_token_ids
+
+            # Append generated tokens and check for stop. Note that if
+            # a request is still being prefilled, we expect the model runner
+            # to return empty token ids for the request.
+            for num_new, output_token_id in enumerate(new_token_ids, 1):
+                request.append_output_token_ids(output_token_id)
+
+                # Check for stop and update request state.
+                # This must be called before we make the EngineCoreOutput.
+                stopped = check_stop(request, self.max_model_len)
+                if stopped:
+                    self._free_request(request)
+                    del new_token_ids[num_new:]  # Trim new tokens if needed.
+                    break
+
+            # Extract sample logprobs if needed.
+            if request.sampling_params.logprobs is not None and logprobs:
+                # NOTE: once we support N tokens per step (spec decode),
+                # the outer lists can be of length > 1.
+                new_logprobs = logprobs.slice(req_index, req_index + 1)
+
+            if new_token_ids and request.use_structured_output:
+                # NOTE: structured_output_request
+                # should not be None if use_structured_output, we have
+                # check above, so safe to ignore type warning
+                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                    req_id, new_token_ids)
+
+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                if request.use_structured_output:
+                    metadata = request.structured_output_request
+                    assert metadata is not None and metadata.grammar is not None
+                    # Needs to happen after new_token_ids are accepted.
+                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                        spec_token_ids[req_index])
+                else:
+                    request.spec_token_ids = spec_token_ids[req_index]
+
+            # Get prompt logprobs for this request.
+            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+            if new_token_ids:
+                # Add EngineCoreOutput for this Request.
+                outputs.append(
+                    EngineCoreOutput(
+                        request_id=req_id,
+                        new_token_ids=new_token_ids,
+                        finish_reason=request.get_finished_reason(),
+                        new_logprobs=new_logprobs,
+                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                        stop_reason=request.stop_reason,
+                        events=request.take_events()))
+            else:
+                # Invariant: EngineCore returns no partial prefill outputs.
+                assert not prompt_logprobs_tensors
+
+            self.scheduled_req_ids.remove(req_id)
+            if not stopped:
+                new_running.append(request)
+
+        # Return the cached request data to the queue so they can be reused.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
+            # to _cached_reqs_data will cause a memory leak.
+            if req_data.req_id not in self.finished_req_ids:
+                self._cached_reqs_data[req_data.req_id].append(req_data)
+
+        self.running = new_running
+        engine_core_outputs = EngineCoreOutputs(
+            outputs=outputs,
+            scheduler_stats=self.make_stats(spec_decoding_stats),
+        )
+        if self.include_finished_set:
+            # TODO currently sending duplicates here, improve this
+            engine_core_outputs.finished_requests = (
+                scheduler_output.finished_req_ids | self.finished_req_ids)
+
+        return engine_core_outputs
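
The speculative-decoding bookkeeping in `update_from_output` is easiest to check with concrete numbers (illustrative values):

# Worked example of the rejection arithmetic above (illustrative token ids).
scheduled_spec_token_ids = [101, 102, 103]  # three draft tokens were scheduled
generated_token_ids = [101, 102, 250]       # two drafts accepted, third corrected

# The step verified len(drafts) + 1 positions; whatever was not returned
# counts as rejected and is subtracted from num_computed_tokens.
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                       len(generated_token_ids))    # 3 + 1 - 3 = 1
num_accepted_tokens = len(generated_token_ids) - 1  # 2

print(num_tokens_rejected, num_accepted_tokens)  # 1 2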

File 3 of 5: CustomDeepseekV2MLAAttention

@@ -31,6 +31,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 import torch.distributed as dist
 import torch_npu
+import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
@@ -396,10 +397,22 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         else:
             hidden_states_or_q_c = hidden_states
         if self.enable_graph_mode:
-            return self.mla_attn.impl.forward(self.mla_attn,
-                                              hidden_states_or_q_c,
-                                              hidden_states, None, kv_cache,
-                                              attn_metadata)
+            forward_kwargs = {}
+            if envs.VLLM_USE_V1:
+                output_shape = hidden_states.shape
+                output = torch.empty(output_shape,
+                                     dtype=hidden_states_or_q_c.dtype,
+                                     device=hidden_states_or_q_c.device)
+                forward_kwargs['output'] = output
+            output = self.mla_attn.impl.forward(self.mla_attn,
+                                                hidden_states_or_q_c,
+                                                hidden_states, None, kv_cache,
+                                                attn_metadata,
+                                                **forward_kwargs)
+            if envs.VLLM_USE_V1:
+                output = output.view(-1, output_shape[-1])
+            return output
         else:
             kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
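
The v1 branch above switches to an output-buffer convention: the caller pre-allocates `output`, the attention impl fills it in place, and the result is flattened back to the hidden size. A minimal sketch with made-up shapes (`impl_forward` stands in for `self.mla_attn.impl.forward`):

import torch

# Sketch of the v1 output-buffer convention; shapes are illustrative.
hidden_states = torch.zeros(8, 5120)

output = torch.empty(hidden_states.shape,
                     dtype=hidden_states.dtype,
                     device=hidden_states.device)

def impl_forward(output: torch.Tensor) -> torch.Tensor:
    # The real impl writes attention results into `output` in place
    # (see the forward() changes in file 1) and returns it.
    return output

out = impl_forward(output=output).view(-1, hidden_states.shape[-1])
print(out.shape)  # torch.Size([8, 5120])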

File 4 of 5: NPUPlatform

@@ -153,9 +153,9 @@ class NPUPlatform(Platform):
                 "enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
             )
             vllm_config.additional_config["enable_graph_mode"] = False
-        if enable_graph_mode and envs.VLLM_USE_V1:
+        if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
             logger.warning(
-                "NPU graph mode is still experimental and not supported for V1 currently, "
+                "NPU graph mode is still experimental and not supported for V1 without mla currently, "
                 "it has been disabled automatically.")
             vllm_config.additional_config["enable_graph_mode"] = False

File 5 of 5: NPUModelRunner

@@ -63,6 +63,8 @@ if TYPE_CHECKING:
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

+import vllm.envs as envs
+

 @dataclass
 class GraphCaptureContext:
@@ -117,6 +119,12 @@ class NPUModelRunner:
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs

+        self.graph_block_tables = np.zeros(
+            (self.vllm_config.scheduler_config.max_num_seqs,
+             (self.model_config.max_model_len + self.block_size - 1) //
+             self.block_size),
+            dtype=np.int32)
+
         # Model-related.
         self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
             vllm_config.parallel_config, LayerBlockType.attention)
@@ -307,6 +315,15 @@ class NPUModelRunner:
                                             self.attn_mask_len, self.dtype)

         self.sampler = Sampler()
+        self.enable_torchair_graph_mode = False
+        self.use_cached_npu_graph = False
+        additional_config = vllm_config.additional_config
+        if additional_config:
+            self.enable_torchair_graph_mode = additional_config.get(
+                "enable_graph_mode",
+                False) and self.vllm_config.model_config.use_mla
+            self.use_cached_npu_graph = additional_config.get(
+                "use_cached_npu_graph", False)

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -563,11 +580,19 @@ class NPUModelRunner:
         self.attn_mask = attn_mask
         self.attn_state = attn_state  # type: ignore

+        extra_builder_kwargs = {}
+
+        # Add graph_pad_size here
+        if self.enable_torchair_graph_mode:
+            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            extra_builder_kwargs['graph_pad_size'] = graph_pad_size
+
         attn_metadata = self.attn_metadata_builder.build(  # type: ignore
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
             common_prefix_len=None,
+            **extra_builder_kwargs,
         )

         # Prepare input_ids
@@ -582,14 +607,44 @@ class NPUModelRunner:
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         input_ids = self.input_ids[:total_num_scheduled_tokens]

+        if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            padding = torch.zeros(graph_pad_size,
+                                  dtype=input_ids.dtype,
+                                  device=input_ids.device)
+            input_ids = torch.cat([input_ids, padding])
+            positions = torch.cat([positions, padding])
+
         # Run forward pass
         with set_forward_context(attn_metadata, self.vllm_config):
+            model_kwargs = {}
+            if self.enable_torchair_graph_mode:
+                model_kwargs["kv_caches"] = self.kv_caches
+                model_kwargs["attn_metadata"] = attn_metadata
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    if isinstance(kv, tuple):
+                        torch._dynamo.mark_static(kv[0])
+                        torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    **model_kwargs,
+                )
+            else:
                 assert self.model is not None
                 hidden_states = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=None,
+                    **model_kwargs,
                 )

         return hidden_states[sample_indices]
@@ -879,6 +934,31 @@ class NPUModelRunner:
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))

+        # adapt torch.compile with npu_backend
+        if self.enable_torchair_graph_mode:
+            import torchair  # type: ignore
+            from torchair import patch_for_hcom  # type: ignore
+
+            patch_for_hcom()
+            config = torchair.CompilerConfig()
+            config.experimental_config.frozen_parameter = True
+            config.experimental_config.tiling_schedule_optimize = True
+            torch.npu.set_compile_mode(jit_compile=False)
+            if not self.use_cached_npu_graph:
+                npu_backend = torchair.get_npu_backend(compiler_config=config)
+                self.compile_model = torch.compile(
+                    self.model,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    backend=npu_backend)
+            else:
+                self.compile_model = torchair.inference.cache_compile(
+                    self.model.forward,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
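
For context, the compile wiring above in isolation; a minimal sketch that assumes the `torchair` package and an NPU-enabled PyTorch build exactly as the diff does (`model` is any nn.Module; the real code also passes `fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE`):

import torch
import torchair  # type: ignore  # Ascend graph compiler, as imported above
from torchair import patch_for_hcom  # type: ignore

def compile_for_npu(model: torch.nn.Module, use_cached: bool):
    patch_for_hcom()  # patch HCCL collectives so they can be graph-captured
    config = torchair.CompilerConfig()
    config.experimental_config.frozen_parameter = True
    config.experimental_config.tiling_schedule_optimize = True
    torch.npu.set_compile_mode(jit_compile=False)
    if not use_cached:
        # Fresh dynamo compile against the torchair backend.
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        return torch.compile(model, dynamic=True, backend=npu_backend)
    # Reuse torchair's on-disk compile cache for faster startup.
    return torchair.inference.cache_compile(model.forward,
                                            dynamic=True,
                                            config=config,
                                            ge_cache=False)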
@@ -909,6 +989,25 @@ class NPUModelRunner:
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
+                    if self.enable_torchair_graph_mode:
+                        layer_kv_cache_nope = torch.zeros(
+                            kv_cache_shape[:-1] +
+                            (self.model_config.hf_text_config.kv_lora_rank, ),
+                            dtype=self.dtype,
+                            pin_memory=True,
+                            device=self.device)
+                        layer_kv_cache_pe = torch.zeros(
+                            kv_cache_shape[:-1] +
+                            (self.model_config.hf_text_config.qk_rope_head_dim,
+                             ),
+                            dtype=self.dtype,
+                            pin_memory=True,
+                            device=self.device)
+                        kv_caches[layer_name] = (layer_kv_cache_nope,
+                                                 layer_kv_cache_pe)
+                        torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
+                        torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                    else:
                         kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                             dtype=dtype,
                                                             device=self.device)
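
Finally, a shape sketch of the split KV cache above: in graph mode each layer's cache is a `(nope, rope)` pair rather than one combined tensor (sizes illustrative; `kv_lora_rank` and `qk_rope_head_dim` come from the model's `hf_text_config` in the real code):

import torch

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads = 1024, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64
kv_cache_shape = (num_blocks, block_size, num_kv_heads,
                  kv_lora_rank + qk_rope_head_dim)

layer_kv_cache_nope = torch.zeros(kv_cache_shape[:-1] + (kv_lora_rank, ))
layer_kv_cache_pe = torch.zeros(kv_cache_shape[:-1] + (qk_rope_head_dim, ))

# Graph mode: kv_caches[layer_name] == (nope, pe)
#   nope: [1024, 128, 1, 512]   pe: [1024, 128, 1, 64]
# Eager mode keeps the single combined tensor of kv_cache_shape; exec_kv()
# writes into the pair, which _forward_decode() views per block (see file 1).
print(layer_kv_cache_nope.shape, layer_kv_cache_pe.shape)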