diff --git a/vllm_ascend/_310p/model_runner_310p.py b/vllm_ascend/_310p/model_runner_310p.py
index 3dceff55..d4a15785 100644
--- a/vllm_ascend/_310p/model_runner_310p.py
+++ b/vllm_ascend/_310p/model_runner_310p.py
@@ -17,6 +17,8 @@
 from __future__ import annotations
 
+from contextlib import contextmanager, nullcontext
+
 import numpy as np
 import torch
 import torch_npu
 
@@ -24,14 +26,151 @@
 from vllm.logger import logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+_NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN = 1
+
 class NPUModelRunner310(NPUModelRunner):
+    # Inherited from the parent runner; annotated here to satisfy strict type checks.
+    uniform_decode_query_len: int
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._acl_format = ACL_FORMAT_FRACTAL_NZ
+        if self.speculative_config is not None and self.speculative_config.method == "ngram":
+            # 310P ngram requires decode-only graph shapes to be built with q_len=1.
+            # Keep the dispatcher's internal query_len in sync to avoid a key-init assert.
+            self.cudagraph_dispatcher.uniform_decode_query_len = _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
+
+    @contextmanager
+    def temporary_modify_uniform_decode_query_len(self):
+        # Only needed for the 310P ngram path, where the dispatcher uses q_len=1
+        # while the runner's default uniform_decode_query_len remains 1 + num_spec_tokens.
+        # TODO: remove this temporary override once upstream supports an independent
+        # decode-capture query_len for backend-specific paths.
+        if self.speculative_config is None or self.speculative_config.method != "ngram":
+            yield
+            return
+
+        original_uniform_decode_query_len = self.uniform_decode_query_len
+        self.uniform_decode_query_len = _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
+        try:
+            yield
+        finally:
+            self.uniform_decode_query_len = original_uniform_decode_query_len
+
+    def _determine_batch_execution_and_padding(
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        num_scheduled_tokens_np: np.ndarray,
+        max_num_scheduled_tokens: int,
+        use_cascade_attn: bool,
+        allow_microbatching: bool = False,
+        force_eager: bool = False,
+        force_uniform_decode: bool | None = None,
+        force_has_lora: bool | None = None,
+        force_num_active_loras: int | None = None,
+        num_encoder_reqs: int = 0,
+    ):
+        if self.attn_state in (AscendAttentionState.ChunkedPrefill, AscendAttentionState.PrefillCacheHit):
+            force_eager = True
+
+        # Respect an explicit caller override: only force uniform decode when unset.
+        if force_uniform_decode is None and self.attn_state == AscendAttentionState.DecodeOnly:
+            decode_query_len = _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
+            if (
+                max_num_scheduled_tokens == decode_query_len
+                and num_tokens == max_num_scheduled_tokens * num_reqs
+                and np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] > 0)
+            ):
+                force_uniform_decode = True
+
+        return super()._determine_batch_execution_and_padding(
+            num_tokens=num_tokens,
+            num_reqs=num_reqs,
+            num_scheduled_tokens_np=num_scheduled_tokens_np,
+            max_num_scheduled_tokens=max_num_scheduled_tokens,
+            use_cascade_attn=use_cascade_attn,
+            allow_microbatching=allow_microbatching,
+            force_eager=force_eager,
+            force_uniform_decode=force_uniform_decode,
+            force_has_lora=force_has_lora,
+            force_num_active_loras=force_num_active_loras,
+            num_encoder_reqs=num_encoder_reqs,
+        )
+
+    def _pad_query_start_loc_for_fia(self, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int) -> int:
+        # Keep this aligned with the dispatcher, because batch_desc.num_reqs is
+        # generated by dispatcher._create_padded_batch_descriptor().
+        # For 310P ngram we intentionally set the dispatcher's q_len=1, while the
+        # runner's default uniform_decode_query_len may remain 1 + num_spec_tokens.
+        uniform_decode_query_len = self.cudagraph_dispatcher.uniform_decode_query_len
+
+        if num_tokens_padded == num_reqs_padded * uniform_decode_query_len:
+            # Uniform-batch case: num_reqs must be no greater than num_reqs_padded.
+            assert num_reqs <= num_reqs_padded
+
+            last_loc = self.query_start_loc.np[num_reqs]
+            self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1] = (
+                self.arange_np[1 : num_reqs_padded + 1 - num_reqs] * uniform_decode_query_len + last_loc
+            )
+        else:
+            # Mixed-batch case: num_reqs must equal num_reqs_padded.
+            assert num_reqs == num_reqs_padded
+
+            # Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly.
+            self.query_start_loc.np[num_reqs_padded + 1] = num_tokens_padded
+            num_reqs_padded = num_reqs_padded + 1
+
+        self.query_start_loc.copy_to_gpu()
+        return num_reqs_padded
+
+    @torch.inference_mode()
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        with_prefill: bool = False,
+        cudagraph_runtime_mode=None,
+        force_attention: bool = False,
+        uniform_decode: bool = False,
+        is_profile: bool = False,
+        create_mixed_batch: bool = False,
+        allow_microbatching: bool = True,
+        skip_eplb: bool = False,
+        remove_lora: bool = True,
+        is_graph_capturing: bool = False,
+        num_active_loras: int = 0,
+    ):
+        temporary_context = self.temporary_modify_uniform_decode_query_len() if uniform_decode else nullcontext()
+        with temporary_context:
+            return super()._dummy_run(
+                num_tokens=num_tokens,
+                with_prefill=with_prefill,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                force_attention=force_attention,
+                uniform_decode=uniform_decode,
+                is_profile=is_profile,
+                create_mixed_batch=create_mixed_batch,
+                allow_microbatching=allow_microbatching,
+                skip_eplb=skip_eplb,
+                remove_lora=remove_lora,
+                is_graph_capturing=is_graph_capturing,
+                num_active_loras=num_active_loras,
+            )
+
+    def _check_and_update_cudagraph_mode(
+        self,
+        attention_backends,
+        kv_cache_groups,
+    ) -> None:
+        # 910B does not need this branch because runner and dispatcher query_len are
+        # naturally consistent there; 310P ngram needs the temporary alignment.
+        with self.temporary_modify_uniform_decode_query_len():
+            super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)
 
     def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
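
Note on the override pattern above: a minimal standalone sketch of the save/override/restore contextmanager that temporary_modify_uniform_decode_query_len implements. The Runner class, temporary_override name, and _OVERRIDE_LEN here are illustrative stand-ins, not the vllm_ascend API; the only assumption is that the attribute being swapped is a plain int on the object.

    from contextlib import contextmanager

    _OVERRIDE_LEN = 1  # stand-in for _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN

    class Runner:
        def __init__(self) -> None:
            self.uniform_decode_query_len = 4  # e.g. 1 + num_spec_tokens

        @contextmanager
        def temporary_override(self):
            original = self.uniform_decode_query_len
            self.uniform_decode_query_len = _OVERRIDE_LEN
            try:
                yield
            finally:
                # Restored even if the body raises, so callers wrapped in this
                # context (as _dummy_run wraps super()._dummy_run) stay consistent.
                self.uniform_decode_query_len = original

    runner = Runner()
    with runner.temporary_override():
        assert runner.uniform_decode_query_len == _OVERRIDE_LEN
    assert runner.uniform_decode_query_len == 4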
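
Note on the padding arithmetic in _pad_query_start_loc_for_fia: a small numpy sketch of the uniform-batch branch, using hypothetical sizes (3 real decode requests padded to 5, q_len=1 as in the 310P ngram setting). It only reproduces the index math on a local array, not the runner's buffers.

    import numpy as np

    num_reqs, num_reqs_padded, q_len = 3, 5, 1

    # query_start_loc for the 3 real requests: [0, 1, 2, 3, ...]
    query_start_loc = np.zeros(num_reqs_padded + 1, dtype=np.int64)
    query_start_loc[1 : num_reqs + 1] = np.arange(1, num_reqs + 1) * q_len
    arange_np = np.arange(num_reqs_padded + 1)

    # Extend the uniform stride across the padded (dummy) requests.
    last_loc = query_start_loc[num_reqs]
    query_start_loc[num_reqs + 1 : num_reqs_padded + 1] = (
        arange_np[1 : num_reqs_padded + 1 - num_reqs] * q_len + last_loc
    )
    print(query_start_loc)  # [0 1 2 3 4 5]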