diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index caf0eae..ff6742d 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,11 +1,14 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar +import numpy as np import torch import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) @@ -51,6 +54,7 @@ class AscendMLAPrefillMetadata: """ Prefill Specific Metadata for Ascend""" attn_mask: torch.Tensor query_lens: list[int] + seq_lens: list[int] context_lens: torch.Tensor input_positions: torch.Tensor block_table: torch.Tensor @@ -66,6 +70,7 @@ class AscendMLADecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor max_seq_lens: int + seq_lens_list: list[int] @dataclass @@ -195,11 +200,38 @@ class AscendMLAMetadataBuilder: return modified_batch + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + if isinstance(self.runner.graph_block_tables, np.ndarray): + graph_block_tables = torch.zeros((max_batch_size, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + else: + graph_block_tables = self.runner.graph_block_tables.to( + device=block_tables.device, dtype=block_tables.dtype) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = 
block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: Optional[int] = None) -> AscendMLAMetadata: + common_prefix_len: Optional[int] = None, + graph_pad_size: int = -1) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -230,6 +262,7 @@ class AscendMLAMetadataBuilder: prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, query_lens=query_lens[tokens_start:], + seq_lens=seq_lens, context_lens=seq_lens[tokens_start:], input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], @@ -238,12 +271,46 @@ class AscendMLAMetadataBuilder: ) decode_metadata = None + use_torchair_graph = graph_pad_size != -1 if self._num_decodes > 0: max_seq_lens = seq_lens[:self._num_decodes].max().item() + seq_lens = seq_lens[:self._num_decode_tokens] + input_positions = input_positions[:self._num_decode_tokens] + block_table = block_table[:self._num_decode_tokens, ...] 
+ if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly: + num_seqs = len(seq_lens) + if graph_pad_size != 0: + pad_value = 1 + padded_seq_lens = seq_lens.tolist() + [pad_value + ] * graph_pad_size + else: + padded_seq_lens = seq_lens.tolist() + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)) + padding = torch.full((graph_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, padding]) + block_table_padding = torch.zeros( + (graph_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_seqs, block_table) + padding_0 = torch.zeros(graph_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat([input_positions, padding_0]) + decode_metadata = AscendMLADecodeMetadata( - input_positions=input_positions[:self._num_decode_tokens], - block_table=block_table[:self._num_decode_tokens, ...], - seq_lens=seq_lens[:self._num_decode_tokens], + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), max_seq_lens=max_seq_lens) return self.metadata_cls( # type: ignore @@ -323,6 +390,8 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_b_proj = kv_b_proj self.o_proj = o_proj + self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) + self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. 
def exec_kv(
    self,
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    kv_cache: Tuple,
    slots: torch.Tensor,
):
    """Project hidden states to the latent KV, apply RMSNorm + RoPE, and
    write the result into the paged KV cache in one fused NPU kernel.

    Args:
        hidden_states: decode-token hidden states; assumed [B, hidden]
            with one token per sequence (S == 1) — TODO confirm.
        cos/sin: rotary tables already gathered at the decode positions.
        kv_cache: (nope_cache, pe_cache) pair; note the kernel takes them
            in (pe, nope) order below.
        slots: slot mapping into the paged cache.

    Returns:
        (k_pe, k_nope) as produced by ``npu_kv_rmsnorm_rope_cache``.
    """
    B = hidden_states.shape[0]
    N = self.num_kv_heads
    S = 1
    kv = self.kv_a_proj_with_mqa(hidden_states)[0]
    # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
    kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
    k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
        kv,
        self.kv_a_layernorm.weight,
        cos,
        sin,
        slots.to(torch.int64),
        kv_cache[1],
        kv_cache[0],
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode="PA",
    )
    return k_pe, k_nope


def rope_single(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply interleaved RoPE to a single-position tensor.

    Temporarily inserts a sequence axis (S == 1) because
    ``npu_interleave_rope`` expects [B, N, S, D], then restores [B, N, D].
    """
    B, N, D = x.shape
    S = 1
    x = x.view(B, N, S, D)
    x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
    return x.view(B, N, D)
def forward(
    self,
    layer: AttentionLayer,
    hidden_states_or_q_c: torch.Tensor,  # query in unified attn
    hidden_states_or_kv_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: M,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run MLA attention over a mixed decode-prefix/prefill-suffix batch.

    In torchair graph mode with a pure-decode batch (``running_in_graph``)
    the inputs arrive already padded for the graph and the decode result
    is returned directly; otherwise results are written into ``output``
    and the (possibly padded) output tensor is returned.
    """
    if attn_metadata is None:
        # Profiling run: no metadata yet, nothing to compute.
        return output
    self.running_in_graph = (
        self.enable_graph_mode
        and attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
    num_actual_toks = attn_metadata.num_actual_tokens
    # FIX: initialise unconditionally; previously output_padded was only
    # assigned on the non-graph path, so a graph-mode fall-through to the
    # final `return output_padded` could raise UnboundLocalError.
    output_padded = output
    if k_pe is None and not self.running_in_graph:
        # Caller handed us raw hidden states: do the latent-KV projection
        # and RMSNorm here instead of in the model code.
        kv_c, k_pe = self.kv_a_proj_with_mqa(
            hidden_states_or_kv_c_normed)[0].split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
    else:
        kv_c_normed = hidden_states_or_kv_c_normed
    assert attn_metadata.num_decodes is not None and \
        attn_metadata.num_prefills is not None and \
        attn_metadata.num_decode_tokens is not None
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens
    if not self.running_in_graph:
        # Inputs and outputs may be padded for graphs; slice down to the
        # real token count and split into decode prefix / prefill suffix.
        # (Two consecutive identical guards in the original are merged.)
        output = output[:num_actual_toks, ...]
        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
        # Restore head dim (for rotary embedding).
        k_pe = k_pe[:num_actual_toks, ...]
        k_pe = k_pe.unsqueeze(1)
        decode_k_pe = k_pe[:num_decode_tokens]
        prefill_k_pe = k_pe[num_decode_tokens:]
    else:
        decode_hs_or_q_c = hidden_states_or_q_c
    if has_decode:
        decode_k_nope = None
        assert attn_metadata.decode is not None
        decode_ql_nope, decode_q_pe = \
            self._q_proj_and_k_up_proj(decode_hs_or_q_c)
        if self.running_in_graph:
            # Graph mode applies RoPE from the cached cos/sin tables and
            # writes the KV cache inside exec_kv (fused NPU kernels).
            seq_len = self.rotary_emb.max_position_embeddings
            cos = self.rotary_emb.cos_cached[:seq_len].to(
                dtype=decode_q_pe.dtype)
            sin = self.rotary_emb.sin_cached[:seq_len].to(
                dtype=decode_q_pe.dtype)
            cos = cos[attn_metadata.decode.input_positions]
            sin = sin[attn_metadata.decode.input_positions]
            cos = cos[:, None, None, :]
            sin = sin[:, None, None, :]
            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
            decode_k_pe, decode_k_nope = self.exec_kv(
                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                attn_metadata.slot_mapping)
        else:
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                attn_metadata.decode.input_positions,
                decode_q_pe.contiguous(),
                decode_k_pe,
                max_seq_len=attn_metadata.decode.max_seq_lens)
    if has_prefill:
        assert attn_metadata.prefill is not None
        prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
            .view(-1, self.num_heads, self.qk_head_dim)
        prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
        prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
        if self.enable_graph_mode:
            num_tokens = prefill_hs_or_q_c.shape[0]
            prefill_k_pe = prefill_k_pe.view(num_tokens,
                                             self.num_kv_heads, -1)
            if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
                # NOTE: base RotaryEmbedding (no scaling configured)
                # expects flat [num_tokens, dim] inputs; flatten, rotate,
                # then restore the original shapes.
                ori_q_pe_shape, ori_k_pe_shape = \
                    prefill_q_pe.shape, prefill_k_pe.shape
                prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
                prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
                prefill_q_pe, prefill_k_pe = self.rotary_emb(
                    attn_metadata.prefill.input_positions, prefill_q_pe,
                    prefill_k_pe)
                prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
                prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
            else:
                prefill_q_pe, prefill_k_pe = self.rotary_emb(
                    attn_metadata.prefill.input_positions, prefill_q_pe,
                    prefill_k_pe)
            prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
        else:
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                attn_metadata.prefill.input_positions,
                prefill_q_pe.contiguous(),
                prefill_k_pe,
                max_seq_len=attn_metadata.prefill.max_seq_lens)
    if self.enable_graph_mode:
        # In graph mode kv_cache is a (nope_cache, pe_cache) pair; only
        # the no-cached-prefix prefill path writes the cache here (decode
        # writes happen inside exec_kv above).
        if len(kv_cache) > 0 and kv_cache[0].numel() > 0 and \
                attn_metadata.attn_state == \
                AscendAttentionState.PrefillNoCache:
            slots = attn_metadata.slot_mapping
            # NOTE: Separate the kv cache in advance to avoid OOM or
            # other issues. `num_tokens`/`prefill_k_pe` come from the
            # has_prefill branch; PrefillNoCache implies has_prefill —
            # TODO confirm that invariant holds for all callers.
            torch_npu._npu_reshape_and_cache(
                key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1),
                value=prefill_k_pe,
                key_cache=kv_cache[0],
                value_cache=kv_cache[1],
                slot_indices=slots)
    elif kv_cache.numel() > 0:
        # Non-graph path keeps nope+rope concatenated in a single cache.
        key = torch.cat([
            kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
            k_pe
        ], dim=2)
        torch_npu._npu_reshape_and_cache_siso(
            key=key,
            key_cache=kv_cache,
            slot_indices=attn_metadata.slot_mapping.flatten())
    if has_prefill:
        output[num_decode_tokens:] = self._forward_prefill(
            prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
            attn_metadata)
    if has_decode:
        if self.running_in_graph:
            # Graph mode returns the decode result directly; no output
            # buffer bookkeeping is needed.
            return self._forward_decode(decode_ql_nope, decode_q_pe,
                                        decode_k_nope, decode_k_pe,
                                        kv_cache, attn_metadata)
        output[:num_decode_tokens] = self._forward_decode(
            decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
            kv_cache, attn_metadata)
    return output_padded
def __init__(
    self,
    vllm_config: VllmConfig,
    kv_cache_config: KVCacheConfig,
    structured_output_manager: StructuredOutputManager,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    include_finished_set: bool = False,
    log_stats: bool = False,
) -> None:
    """Delegate to vllm's v1 Scheduler, then add the bookkeeping this
    subclass needs for prefill-first scheduling."""
    super().__init__(vllm_config, kv_cache_config,
                     structured_output_manager, mm_registry,
                     include_finished_set, log_stats)
    # Request ids scheduled in the current step; drained as their
    # outputs arrive in update_from_output().
    self.scheduled_req_ids: set[str] = set()
    # Re-declared for typing; the base class owns the running queue.
    self.running: list[Request] = []


def finish_requests(
    self,
    request_ids: Union[str, Iterable[str]],
    finished_status: RequestStatus,
) -> None:
    """Handles the finish signal from outside the scheduler.

    For example, the API server can abort a request when the client
    disconnects.
    """
    assert RequestStatus.is_finished(finished_status)
    if isinstance(request_ids, str):
        request_ids = (request_ids, )
    else:
        # De-duplicate so each request is freed at most once.
        request_ids = set(request_ids)

    for req_id in request_ids:
        request = self.requests.get(req_id)
        if request is None:
            # Invalid request ID.
            continue

        if request.status == RequestStatus.RUNNING:
            self.running.remove(request)
            self.scheduled_req_ids.discard(request.request_id)
        else:
            self.waiting.remove(request)
        request.status = finished_status
        self._free_request(request)


def update_from_output(
    self,
    scheduler_output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
    """Reconcile the model runner's results with scheduler state and
    build the EngineCoreOutputs for this step.

    Mirrors the base scheduler's logic but also drains
    ``self.scheduled_req_ids`` as each scheduled request reports back.
    """
    sampled_token_ids = model_runner_output.sampled_token_ids
    spec_token_ids = model_runner_output.spec_token_ids
    logprobs = model_runner_output.logprobs
    prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens

    new_running: list[Request] = []
    outputs: list[EngineCoreOutput] = []
    spec_decoding_stats: Optional[SpecDecodingStats] = None

    # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
    # loop can be a performance bottleneck. We should do our best to avoid
    # expensive operations inside the loop.
    for request in self.running:
        req_id = request.request_id
        num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
        if num_tokens_scheduled == 0:
            # The request was not scheduled in this step.
            new_running.append(request)
            continue

        req_index = model_runner_output.req_id_to_index[req_id]
        generated_token_ids = sampled_token_ids[req_index]

        scheduled_spec_token_ids = (
            scheduler_output.scheduled_spec_decode_tokens.get(req_id))
        if scheduled_spec_token_ids:
            # num_computed_tokens represents the number of tokens
            # processed in the current step, considering scheduled
            # tokens and rejections. If some tokens are rejected,
            # num_computed_tokens is decreased by the number of rejected
            # tokens, which is given by:
            # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
            num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                                   len(generated_token_ids))
            request.num_computed_tokens -= num_tokens_rejected
            spec_decoding_stats = self.make_spec_decoding_stats(
                spec_decoding_stats,
                num_draft_tokens=len(scheduled_spec_token_ids),
                num_accepted_tokens=len(generated_token_ids) - 1)

        cached_encoder_input_ids = (
            self.encoder_cache_manager.get_cached_input_ids(request))
        # OPTIMIZATION: Avoid list(set) if the set is empty.
        if cached_encoder_input_ids:
            for input_id in list(cached_encoder_input_ids):
                mm_positions = request.mm_positions[input_id]
                start_pos = mm_positions.offset
                num_tokens = mm_positions.length
                if start_pos + num_tokens <= request.num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    self.encoder_cache_manager.free_encoder_input(
                        request, input_id)

        stopped = False
        new_logprobs = None
        new_token_ids = generated_token_ids

        # Append generated tokens and check for stop. Note that if
        # a request is still being prefilled, we expect the model runner
        # to return empty token ids for the request.
        for num_new, output_token_id in enumerate(new_token_ids, 1):
            request.append_output_token_ids(output_token_id)

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
            stopped = check_stop(request, self.max_model_len)
            if stopped:
                self._free_request(request)
                del new_token_ids[num_new:]  # Trim new tokens if needed.
                break

        # Extract sample logprobs if needed.
        if request.sampling_params.logprobs is not None and logprobs:
            # NOTE: once we support N tokens per step (spec decode),
            # the outer lists can be of length > 1.
            new_logprobs = logprobs.slice(req_index, req_index + 1)

        if new_token_ids and request.use_structured_output:
            # NOTE: structured_output_request
            # should not be None if use_structured_output, we have
            # check above, so safe to ignore type warning
            request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                req_id, new_token_ids)

        # Add newly generated spec token ids to the request.
        if spec_token_ids is not None:
            if request.use_structured_output:
                metadata = request.structured_output_request
                assert metadata is not None and metadata.grammar is not None
                # Needs to happen after new_token_ids are accepted.
                request.spec_token_ids = metadata.grammar.validate_tokens(
                    spec_token_ids[req_index])
            else:
                request.spec_token_ids = spec_token_ids[req_index]

        # Get prompt logprobs for this request.
        prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
        if new_token_ids:
            # Add EngineCoreOutput for this Request.
            outputs.append(
                EngineCoreOutput(
                    request_id=req_id,
                    new_token_ids=new_token_ids,
                    finish_reason=request.get_finished_reason(),
                    new_logprobs=new_logprobs,
                    new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                    stop_reason=request.stop_reason,
                    events=request.take_events()))
        else:
            # Invariant: EngineCore returns no partial prefill outputs.
            assert not prompt_logprobs_tensors

        # FIX: discard instead of remove — requests scheduled through the
        # base class's chunked-prefill schedule() path are never added to
        # scheduled_req_ids, and remove() would raise KeyError for them.
        self.scheduled_req_ids.discard(req_id)
        if not stopped:
            new_running.append(request)

    # Return the cached request data to the queue so they can be reused.
    for req_data in scheduler_output.scheduled_cached_reqs:
        # NOTE(rob): since we free stopped reqs above, adding stopped reqs
        # to _cached_reqs_data will cause a memory leak.
        if req_data.req_id not in self.finished_req_ids:
            self._cached_reqs_data[req_data.req_id].append(req_data)

    self.running = new_running
    engine_core_outputs = EngineCoreOutputs(
        outputs=outputs,
        scheduler_stats=self.make_stats(spec_decoding_stats),
    )
    if self.include_finished_set:
        # TODO currently sending duplicates here, improve this
        engine_core_outputs.finished_requests = (
            scheduler_output.finished_req_ids | self.finished_req_ids)

    return engine_core_outputs
pass diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 828d0e5..284abc6 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -153,9 +153,9 @@ class NPUPlatform(Platform): "enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode" ) vllm_config.additional_config["enable_graph_mode"] = False - if enable_graph_mode and envs.VLLM_USE_V1: + if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE: logger.warning( - "NPU graph mode is still experimental and not supported for V1 currently, " + "NPU graph mode is still experimental and not supported for V1 without mla currently, " "it has been disabled automatically.") vllm_config.additional_config["enable_graph_mode"] = False diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 08475c4..c9870e0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -63,6 +63,8 @@ if TYPE_CHECKING: else: xgr = LazyLoader("xgr", globals(), "xgrammar") +import vllm.envs as envs + @dataclass class GraphCaptureContext: @@ -117,6 +119,12 @@ class NPUModelRunner: self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs + self.graph_block_tables = np.zeros( + (self.vllm_config.scheduler_config.max_num_seqs, + (self.model_config.max_model_len + self.block_size - 1) // + self.block_size), + dtype=np.int32) + # Model-related. 
self.num_attn_layers = self.model_config.get_num_layers_by_block_type( vllm_config.parallel_config, LayerBlockType.attention) @@ -307,6 +315,15 @@ class NPUModelRunner: self.attn_mask_len, self.dtype) self.sampler = Sampler() + self.enable_torchair_graph_mode = False + self.use_cached_npu_graph = False + additional_config = vllm_config.additional_config + if additional_config: + self.enable_torchair_graph_mode = additional_config.get( + "enable_graph_mode", + False) and self.vllm_config.model_config.use_mla + self.use_cached_npu_graph = additional_config.get( + "use_cached_npu_graph", False) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -563,11 +580,19 @@ class NPUModelRunner: self.attn_mask = attn_mask self.attn_state = attn_state # type: ignore + extra_builder_kwargs = {} + + # Add graph_pad_size here + if self.enable_torchair_graph_mode: + graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens) + extra_builder_kwargs['graph_pad_size'] = graph_pad_size + attn_metadata = self.attn_metadata_builder.build( # type: ignore num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, common_prefix_len=None, + **extra_builder_kwargs, ) # Prepare input_ids @@ -582,15 +607,45 @@ class NPUModelRunner: self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) input_ids = self.input_ids[:total_num_scheduled_tokens] + if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + padding = torch.zeros(graph_pad_size, + dtype=input_ids.dtype, + device=input_ids.device) + input_ids = torch.cat([input_ids, padding]) + positions = torch.cat([positions, padding]) + # Run forward pass with set_forward_context(attn_metadata, self.vllm_config): - assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - 
intermediate_tensors=intermediate_tensors, - inputs_embeds=None, - ) + model_kwargs = {} + if self.enable_torchair_graph_mode: + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata + if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static(attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + for kv in self.kv_caches: + if isinstance(kv, tuple): + torch._dynamo.mark_static(kv[0]) + torch._dynamo.mark_static(kv[1]) + hidden_states = self.compile_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + **model_kwargs, + ) + else: + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + **model_kwargs, + ) return hidden_states[sample_indices] @@ -879,6 +934,31 @@ class NPUModelRunner: logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) + # adapter torch compile with npu_backend + if self.enable_torchair_graph_mode: + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + torch.npu.set_compile_mode(jit_compile=False) + if not self.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.compile_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + else: + self.compile_model = torchair.inference.cache_compile( + self.model.forward, + dynamic=True, + 
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -909,10 +989,29 @@ class NPUModelRunner: num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - torch_npu.npu_format_cast(kv_caches[layer_name], 2) + if self.enable_torchair_graph_mode: + layer_kv_cache_nope = torch.zeros( + kv_cache_shape[:-1] + + (self.model_config.hf_text_config.kv_lora_rank, ), + dtype=self.dtype, + pin_memory=True, + device=self.device) + layer_kv_cache_pe = torch.zeros( + kv_cache_shape[:-1] + + (self.model_config.hf_text_config.qk_rope_head_dim, + ), + dtype=self.dtype, + pin_memory=True, + device=self.device) + kv_caches[layer_name] = (layer_kv_cache_nope, + layer_kv_cache_pe) + torch_npu.npu_format_cast(kv_caches[layer_name][0], 2) + torch_npu.npu_format_cast(kv_caches[layer_name][1], 2) + else: + kv_caches[layer_name] = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + torch_npu.npu_format_cast(kv_caches[layer_name], 2) else: # TODO: add new branches when introducing more types of # KV cache specs.