diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 62b1e1e..dce98d1 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,21 +17,29 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # +import types from typing import Optional import torch +import torch.distributed as dist +import torch.nn as nn import torch_npu +import vllm.envs as envs_vllm from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import logger +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.platform import NPUPlatform from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, check_torchair_cache_exist, register_torchair_model, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - maybe_converting_weight_acl_format) + is_310p, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): super().__init__(vllm_config, device) + ascend_config = get_ascend_config() + self.new_kv_cache_bytes = -1 + self.torchair_compiled_model = None # type: ignore + self.torchair_compiled_models = {} # type: ignore + self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph + self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes + if ascend_config.torchair_graph_config.graph_batch_sizes_init: + self.init_torchair_graph_batch_sizes() + + self.check_torchair_graph_batch_sizes() + + torch._dynamo.cache_size.config.cache_size_limit += len( + 
self.torchair_graph_batch_sizes) + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._logging.set_logs( + recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) + + self._check_batch_sizes_consistency() register_torchair_model() def _get_forward_metadata_across_dp_and_pad( @@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner): if self.new_kv_cache_bytes > 0: write_kv_cache_bytes_to_file(torch.distributed.get_rank(), self.new_kv_cache_bytes) + + def _use_aclgraph(self) -> bool: + return False + + def _check_batch_sizes_consistency(self) -> None: + if not dist.is_initialized(): + return + + local = torch.tensor(self.torchair_graph_batch_sizes, + device="cpu", + dtype=torch.int32) + gathered_graph_batch_size = local.clone() + dist.all_reduce(gathered_graph_batch_size, + group=get_dp_group().cpu_group) + expected = local * self.dp_size + + if not torch.equal(gathered_graph_batch_size, expected): + diff_idxs = (gathered_graph_batch_size != expected).nonzero( + as_tuple=False).flatten().tolist() + raise AssertionError( + f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n" + f"Local (rank {self.dp_rank}): {local.tolist()}\n" + f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n" + f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}" + ) + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + if not with_prefill: + self.graph_pad_size = graph_pad_size + else: + super()._update_graph_pad_size(with_prefill, graph_pad_size) + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + """Override from NPUModelRunner to update input_ids and positions""" + input_ids, positions = super()._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) + + if not with_prefill: + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = 
self.positions[:padded_num_tokens_across_dp] + return input_ids, positions + + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + model_kwargs = { + "kv_caches": self.kv_caches, + "attn_metadata": attn_metadata + } + if not with_prefill: + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_NZ) + + compiled_model = self._get_torchair_lazy_compiled_model( + padded_num_tokens_across_dp) + hidden_states = compiled_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + else: + assert self.model is not None + maybe_converting_weight_acl_format(self.model, + ACL_FORMAT_FRACTAL_ND) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + return hidden_states + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + import torchair # type: ignore + from torchair import patch_for_hcom # type: ignore + + patch_for_hcom() + + if is_310p(): + # on 300I Duo platform, we need to patch broadcast. however, this patch will be + # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. 
+ from vllm_ascend.patch.platform.patch_common.patch_distributed import \ + communication_adaptation_310p + communication_adaptation_310p() + + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to + # disable it on 300I Duo platform now. + config.experimental_config.tiling_schedule_optimize = not is_310p() + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=True, + fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] + + def init_torchair_graph_batch_sizes(self): + start_graph_batch_size = 4 + tp_size = get_tensor_model_parallel_world_size() + + # NOTE: When use all2all | mc2, We need to slice the 
`num_tokens` dimension into `tp_size` blocks + start_graph_batch_size = max(start_graph_batch_size, tp_size) + + while (start_graph_batch_size <= self.max_num_reqs): + self.torchair_graph_batch_sizes.append(start_graph_batch_size) + start_graph_batch_size *= 2 + + def select_torchair_padded_batch_size(self, batch_size: int): + for padded_batch_size in self.torchair_graph_batch_sizes: + if batch_size <= padded_batch_size: + # we treat batch_size as num of requests + return padded_batch_size + raise ValueError( + f"cur batch_size is invalid, torchair_graph_batch_sizes is " + f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." + ) + + def check_torchair_graph_batch_sizes(self): + # return graph_batch_sizes according to the max number of tokens + # first pad according to the number of requests + if len(self.torchair_graph_batch_sizes) == 0: + self.torchair_graph_batch_sizes = [1, self.max_num_reqs] + else: + self.torchair_graph_batch_sizes = sorted( + self.torchair_graph_batch_sizes) + while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs: + self.torchair_graph_batch_sizes.pop() + if len(self.torchair_graph_batch_sizes) == 0: + logger.warning( + "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]" + ) + self.torchair_graph_batch_sizes = [1, self.max_num_reqs] + if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs: + self.torchair_graph_batch_sizes.append(self.max_num_reqs) + + # padded max number tokens = max_num_req * decode_token_per_req + self.torchair_graph_batch_sizes = [ + graph_batch_size * self.decode_token_per_req + for graph_batch_size in self.torchair_graph_batch_sizes + ] + + # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size` + tp_size = self.parallel_config.tensor_parallel_size + if self.parallel_config.enable_expert_parallel: + new_graph_batch_sizes = [] + for graph_batch_size in self.torchair_graph_batch_sizes: + cur_graph_batch_size = (graph_batch_size 
+ tp_size - + 1) // tp_size * tp_size + if cur_graph_batch_size not in new_graph_batch_sizes and \ + cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: + new_graph_batch_sizes.append(cur_graph_batch_size) + elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ + and self.decode_token_per_req > 1: + logger.warning( + f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens " + f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." + ) + self.torchair_graph_batch_sizes = new_graph_batch_sizes + + def _build_drafter_prepare_inputs_torchair_param(self): + return True diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f4169cf..b55cc13 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -22,7 +22,6 @@ import gc import math import os import time -import types from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast @@ -39,7 +38,6 @@ from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -108,7 +106,6 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") import torch_npu -import vllm.envs as envs_vllm import vllm_ascend.envs as envs_ascend @@ -341,11 +338,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): pin_memory=True) self.seq_lens_np = self.seq_lens_cpu.numpy() - self.use_aclgraph = ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and 
self.compilation_config.level == CompilationLevel.PIECEWISE - and not self.model_config.enforce_eager - and not ascend_config.torchair_graph_config.enabled) + self.use_aclgraph = self._use_aclgraph() self.aclgraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) @@ -357,31 +350,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None - self.new_kv_cache_bytes = -1 - self.torchair_compiled_model = None # type: ignore - self.torchair_compiled_models = {} # type: ignore - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph - self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes - if ascend_config.torchair_graph_config.graph_batch_sizes_init: - self.init_torchair_graph_batch_sizes() - - self.check_torchair_graph_batch_sizes() - - # graph_block_tables shape: [num_request, cell(max_model_len / block_size)] - self.graph_block_tables = np.zeros( - (self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req, - (self.model_config.max_model_len + self.block_size - 1) // - self.block_size), - dtype=np.int32) - - torch._dynamo.cache_size.config.cache_size_limit += len( - self.torchair_graph_batch_sizes) - torch._dynamo.config.capture_dynamic_output_shape_ops = True - torch._logging.set_logs( - recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) - - self.check_batch_sizes_consistency() # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True self.in_profile_run = False @@ -400,27 +368,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.moe_comm_method = AllGatherCommImpl - def check_batch_sizes_consistency(self) -> None: - if not dist.is_initialized(): - return - - local = torch.tensor(self.torchair_graph_batch_sizes, - device="cpu", - dtype=torch.int32) - gathered_graph_batch_size = local.clone() 
- dist.all_reduce(gathered_graph_batch_size, - group=get_dp_group().cpu_group) - expected = local * self.dp_size - - if not torch.equal(gathered_graph_batch_size, expected): - diff_idxs = (gathered_graph_batch_size != expected).nonzero( - as_tuple=False).flatten().tolist() - raise AssertionError( - f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n" - f"Local (rank {self.dp_rank}): {local.tolist()}\n" - f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n" - f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}" - ) + def _use_aclgraph(self) -> bool: + return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. @@ -1047,14 +996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - def _process_reqs( + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, - AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata, - torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray, - Optional[set[str]], Optional[set[str]]]: + AscendTorchairMetadata], torch.Tensor, np.ndarray, int, + torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, + Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor]]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -1103,9 +1053,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): cu_num_tokens = np.cumsum(num_scheduled_tokens) cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens) - logits_indices = cu_num_tokens - 1 - logits_indices = 
torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets positions_np = self.positions_np[:total_num_scheduled_tokens] @@ -1118,7 +1065,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) - if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], @@ -1127,7 +1073,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.positions[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - positions = self.positions[:num_input_tokens] + self.query_lens = torch.from_numpy(num_scheduled_tokens) self.seq_lens_np[:num_reqs] = ( @@ -1145,34 +1091,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - ascend_config = get_ascend_config() - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): - attn_state = AscendAttentionState.PrefillNoCache - # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. - elif np.all(num_scheduled_tokens == 1): - attn_state = AscendAttentionState.DecodeOnly - if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': - # SpecDecoding now supports seq_len=1 and seq_len=2 - # In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1 - attn_state = AscendAttentionState.SpecDecoding - # Speculative decoding. 
- elif np.all(num_valid_tokens == 1): - if self.use_eagle: - attn_state = AscendAttentionState.ChunkedPrefill - else: - attn_state = AscendAttentionState.SpecDecoding - # splitfuse - elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: - attn_state = AscendAttentionState.ChunkedPrefill - else: - attn_state = AscendAttentionState.PrefillCacheHit + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) self.attn_mask = self._make_attention_mask( seq_lens=seq_lens, query_lens=num_scheduled_tokens, - position=positions, + position=self.positions[:num_input_tokens], attn_state=attn_state) self.attn_state = attn_state # type: ignore @@ -1191,8 +1116,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] - is_only_prefill = bool(np.all(num_valid_tokens != 1)) - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) @@ -1202,10 +1125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): total_num_scheduled_tokens, with_prefill, enable_dbo) self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp - if self.torchair_graph_enabled and not with_prefill: - self.graph_pad_size = padded_num_tokens_across_dp - else: - self.graph_pad_size = -1 + self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp) common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], @@ -1221,7 +1141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, enable_dbo_across_dp=enable_dbo, - is_only_prefill=is_only_prefill, + is_only_prefill=bool(np.all(num_valid_tokens != 1)), max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, @@ -1248,10 +1168,7 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. @@ -1273,12 +1190,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): # then the embedding layer is not included in the ACL graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_input_tokens] - - if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_num_tokens_across_dp] - positions = self.positions[:padded_num_tokens_across_dp] + positions = self.positions[:num_input_tokens] + input_ids, positions = self._update_input_ids_and_positions( + input_ids, positions, num_input_tokens, with_prefill, + padded_num_tokens_across_dp) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1293,8 +1208,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) - moe_comm_method = self.moe_comm_method - # NOTE: Currently this padding logic is really messy, # MC2 may not be available in eager mode # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP @@ -1303,52 +1216,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: num_input_tokens = padded_num_tokens_across_dp - # Run forward pass - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - with_prefill=with_prefill, - reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method(self.device, self.dtype, - self.model_config.hf_config), - num_actual_tokens=total_num_scheduled_tokens): - with 
ProfileExecuteDuration().capture_async("forward"): - self.maybe_setup_kv_connector(scheduler_output) - model_kwargs = {} - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.kv_caches - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled and not with_prefill: - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_NZ) - - compiled_model = self._get_torchair_lazy_compiled_model( - padded_num_tokens_across_dp) - hidden_states = compiled_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - else: - assert self.model is not None - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) - - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - - self.maybe_wait_for_kv_save() - finished_sending, finished_recving = self.get_finished_kv_transfer( - scheduler_output) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1358,6 +1225,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. spec_decode_metadata = None + logits_indices = torch.from_numpy(cu_num_tokens - 1).to( + self.device, non_blocking=True) else: # Get the number of draft tokens for each request. 
# Iterate over the dictionary rather than all requests since not all @@ -1372,13 +1241,61 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - aux_hidden_states = None - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = hidden_states + return (attn_metadata, positions, num_scheduled_tokens, + num_input_tokens, num_tokens_across_dp, + padded_num_tokens_across_dp, logits_indices, + spec_decode_metadata, input_ids, inputs_embeds, + intermediate_tensors) - return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens, finished_sending, finished_recving) + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, + padded_num_tokens_across_dp, + input_ids, positions, + intermediate_tensors, + inputs_embeds): + assert self.model is not None + maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def _build_attn_state(self, num_reqs, num_scheduled_tokens, + num_valid_tokens): + ascend_config = get_ascend_config() + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillNoCache + # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': + # SpecDecoding now supports seq_len=1 and seq_len=2 + # In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1 + attn_state = AscendAttentionState.SpecDecoding + # Speculative decoding. 
+ elif np.all(num_valid_tokens == 1): + if self.use_eagle: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.SpecDecoding + # splitfuse + elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.PrefillCacheHit + return attn_state + + def _update_graph_pad_size(self, with_prefill, graph_pad_size): + self.graph_pad_size = -1 + + def _update_input_ids_and_positions(self, input_ids, positions, + num_input_tokens, with_prefill, + padded_num_tokens_across_dp): + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + return input_ids, positions def _get_cumsum_and_arange( self, @@ -1623,8 +1540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - with ProfileExecuteDuration().capture_async( - "prepare input and forward"): + with ProfileExecuteDuration().capture_async("prepare input"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -1634,11 +1550,41 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Return empty ModelRunnerOuptut if there's no work to do. 
return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) - (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens_np, finished_sending, - finished_recving) = (self._process_reqs(scheduler_output, - intermediate_tensors)) + (attn_metadata, positions, num_scheduled_tokens_np, + num_input_tokens, num_tokens_across_dp, + padded_num_tokens_across_dp, logits_indices, spec_decode_metadata, + input_ids, inputs_embeds, + intermediate_tensors) = (self._prepare_inputs( + scheduler_output, intermediate_tensors)) + + # Run forward pass + with ProfileExecuteDuration().capture_async("forward"): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=self.with_prefill, + reserved_mc2_mask=self.reserved_mc2_mask, + moe_comm_method=self.moe_comm_method( + self.device, self.dtype, self.model_config.hf_config), + num_actual_tokens=scheduler_output. 
+ total_num_scheduled_tokens): + self.maybe_setup_kv_connector(scheduler_output) + + hidden_states = self._generate_process_reqs_hidden_states( + attn_metadata, self.with_prefill, + padded_num_tokens_across_dp, input_ids, positions, + intermediate_tensors, inputs_embeds) + + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = self.get_finished_kv_transfer( + scheduler_output) + + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states + kv_connector_output = None if finished_sending is not None or finished_recving is not None: kv_connector_output = KVConnectorOutput( @@ -1667,10 +1613,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): logits = None else: if self.input_batch.pooling_params: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, - finished_sending, finished_recving, - kv_connector_output) + return self._pool( + hidden_states, + scheduler_output.total_num_scheduled_tokens, + num_scheduled_tokens_np, finished_sending, + finished_recving, kv_connector_output) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: @@ -1746,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Compute prompt logprobs if needed. 
prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:num_scheduled_tokens], + hidden_states[:scheduler_output.total_num_scheduled_tokens], scheduler_output, ) @@ -1796,7 +1743,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output, spec_decode_metadata, positions, - num_scheduled_tokens, + scheduler_output.total_num_scheduled_tokens, hidden_states, attn_metadata, aux_hidden_states, @@ -2191,72 +2138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - - patch_for_hcom() - - if is_310p(): - # on 300I Duo platform, we need to patch broadcast. however, this patch will be - # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. - from vllm_ascend.patch.platform.patch_common.patch_distributed import \ - communication_adaptation_310p - communication_adaptation_310p() - - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to - # disable it on 300I Duo platform now. 
- config.experimental_config.tiling_schedule_optimize = not is_310p() - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=True, - fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - def _convert_torch_format(self, tensor): tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor @@ -2707,7 +2588,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens], attn_metadata.slot_mapping[:num_scheduled_tokens], - is_torchair_graph=self.torchair_graph_enabled, + is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(), ) draft_token_ids = self.drafter.propose( @@ -2818,72 +2699,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): return 
prompt_logprobs_dict - def init_torchair_graph_batch_sizes(self): - start_graph_batch_size = 4 - tp_size = get_tensor_model_parallel_world_size() - - # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks - start_graph_batch_size = max(start_graph_batch_size, tp_size) - - while (start_graph_batch_size <= self.max_num_reqs): - self.torchair_graph_batch_sizes.append(start_graph_batch_size) - start_graph_batch_size *= 2 - - def select_torchair_padded_batch_size(self, batch_size: int): - for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size: - # we treat batch_size as num of requests - return padded_batch_size - raise ValueError( - f"cur batch_size is invalid, torchair_graph_batch_sizes is " - f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." - ) - - def check_torchair_graph_batch_sizes(self): - # return graph_batch_sizes according to the max number of tokens - # first pad according to the number of requests - if len(self.torchair_graph_batch_sizes) == 0: - self.torchair_graph_batch_sizes = [1, self.max_num_reqs] - else: - self.torchair_graph_batch_sizes = sorted( - self.torchair_graph_batch_sizes) - while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs: - self.torchair_graph_batch_sizes.pop() - if len(self.torchair_graph_batch_sizes) == 0: - logger.warning( - "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]" - ) - self.torchair_graph_batch_sizes = [1, self.max_num_reqs] - if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs: - self.torchair_graph_batch_sizes.append(self.max_num_reqs) - - # padded max number tokens = max_num_req * decode_token_per_req - self.torchair_graph_batch_sizes = [ - graph_batch_size * self.decode_token_per_req - for graph_batch_size in self.torchair_graph_batch_sizes - ] - - # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size` - tp_size = 
self.parallel_config.tensor_parallel_size - if self.parallel_config.enable_expert_parallel: - new_graph_batch_sizes = [] - for graph_batch_size in self.torchair_graph_batch_sizes: - cur_graph_batch_size = (graph_batch_size + tp_size - - 1) // tp_size * tp_size - if cur_graph_batch_size not in new_graph_batch_sizes and \ - cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: - new_graph_batch_sizes.append(cur_graph_batch_size) - elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ - and self.decode_token_per_req > 1: - logger.warning( - f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens", - f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." - ) - self.torchair_graph_batch_sizes = new_graph_batch_sizes - def get_supported_pooling_tasks(self): model = self.get_model() if not is_pooling_model(model): return [] return list(model.pooler.get_supported_tasks()) + + def _build_drafter_prepare_inputs_torchair_param(self): + return False diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 9493143..61320fa 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -48,6 +48,8 @@ class MtpProposer: device=self.runner.device) self.torchair_compiled_model = None # type: ignore self.torchair_compiled_models = {} # type: ignore + self.torchair_graph_enabled = get_ascend_config( + ).torchair_graph_config.enabled @staticmethod def prepare_inputs( @@ -136,7 +138,7 @@ class MtpProposer: self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. 
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - if token_indices is not None and self.runner.torchair_graph_enabled: + if token_indices is not None and self.torchair_graph_enabled: last_token_indices = token_indices self.input_ids[last_token_indices] = next_token_ids @@ -154,7 +156,7 @@ class MtpProposer: # input_batch=self.runner.input_batch, # scheduler_output=self.runner.scheduler_output, # ) - is_running_torchair = self.runner.torchair_graph_enabled and \ + is_running_torchair = self.torchair_graph_enabled and \ not self.runner.with_prefill if is_running_torchair: @@ -193,7 +195,7 @@ class MtpProposer: attn_metadata.prefill.input_positions = target_positions attn_metadata.prefill.seq_lens = seq_lens - if not self.runner.torchair_graph_enabled: + if not self.torchair_graph_enabled: # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later (num_input_tokens, num_tokens_across_dp, with_prefill, @@ -216,7 +218,7 @@ class MtpProposer: with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata - if self.runner.torchair_graph_enabled: + if self.torchair_graph_enabled: model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] if is_running_torchair: torchair_compiled_model = self._get_torchair_lazy_compiled_model( @@ -280,12 +282,12 @@ class MtpProposer: skip_attn: bool = False, num_reqs: int = 0, num_tokens_across_dp=None) -> None: - if not self.runner.torchair_graph_enabled: + if not self.torchair_graph_enabled: # TODO: adapt enable_dbo later (num_tokens, num_tokens_across_dp, with_prefill, _) = self.runner._get_forward_metadata_across_dp_and_pad( num_tokens, with_prefill, False) - is_running_torchair = self.runner.torchair_graph_enabled and \ + is_running_torchair = self.torchair_graph_enabled and \ not with_prefill if is_running_torchair: