Refactor model runner v1 (#2461)

Refactor model runner v1

### What this PR does / why we need it?
1. Separate the model-execution logic from the input-preparation logic.
2. Move the torchair-specific logic out of model runner v1.

- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-08-21 08:54:57 +08:00
committed by GitHub
parent 1de16ead8e
commit 0dca4c6dbd
3 changed files with 368 additions and 307 deletions

View File

@@ -17,21 +17,29 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
import types
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
check_torchair_cache_exist,
register_torchair_model,
write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
maybe_converting_weight_acl_format)
is_310p, maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Build the torchair runner on top of the base NPU model runner.

    After the base initialisation this reads the torchair graph settings
    from the ascend config, normalises the graph batch sizes, adjusts the
    dynamo settings, verifies the batch sizes across ranks and finally
    calls ``register_torchair_model()``.
    """
    super().__init__(vllm_config, device)
    graph_config = get_ascend_config().torchair_graph_config

    # -1 means "nothing recorded yet".
    self.new_kv_cache_bytes = -1
    # Single compiled model (non-cached mode) and the per-batch-size
    # compiled models (cached mode); both are filled lazily.
    self.torchair_compiled_model = None  # type: ignore
    self.torchair_compiled_models = {}  # type: ignore
    self.use_cached_npu_graph = graph_config.use_cached_graph
    self.torchair_graph_batch_sizes = graph_config.graph_batch_sizes

    if graph_config.graph_batch_sizes_init:
        self.init_torchair_graph_batch_sizes()
    self.check_torchair_graph_batch_sizes()

    # Leave room in the dynamo cache for one entry per graph batch size.
    torch._dynamo.cache_size.config.cache_size_limit += len(
        self.torchair_graph_batch_sizes)
    torch._dynamo.config.capture_dynamic_output_shape_ops = True
    torch._logging.set_logs(
        recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

    self._check_batch_sizes_consistency()
    register_torchair_model()
def _get_forward_metadata_across_dp_and_pad(
@@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner):
if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
def _use_aclgraph(self) -> bool:
return False
def _check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
local = torch.tensor(self.torchair_graph_batch_sizes,
device="cpu",
dtype=torch.int32)
gathered_graph_batch_size = local.clone()
dist.all_reduce(gathered_graph_batch_size,
group=get_dp_group().cpu_group)
expected = local * self.dp_size
if not torch.equal(gathered_graph_batch_size, expected):
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
as_tuple=False).flatten().tolist()
raise AssertionError(
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
)
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
if not with_prefill:
self.graph_pad_size = graph_pad_size
else:
super()._update_graph_pad_size(with_prefill, graph_pad_size)
def _update_input_ids_and_positions(self, input_ids, positions,
                                    num_input_tokens, with_prefill,
                                    padded_num_tokens_across_dp):
    """Override from NPUModelRunner to update input_ids and positions.

    The base implementation always runs first. For prefill steps its result
    is returned unchanged; otherwise both tensors are replaced by slices of
    ``self.input_ids`` / ``self.positions`` of length
    ``padded_num_tokens_across_dp``.
    """
    base_input_ids, base_positions = super()._update_input_ids_and_positions(
        input_ids, positions, num_input_tokens, with_prefill,
        padded_num_tokens_across_dp)

    if with_prefill:
        return base_input_ids, base_positions

    return (self.input_ids[:padded_num_tokens_across_dp],
            self.positions[:padded_num_tokens_across_dp])
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                                         padded_num_tokens_across_dp,
                                         input_ids, positions,
                                         intermediate_tensors,
                                         inputs_embeds):
    """Run the forward pass and return the hidden states.

    Prefill steps call ``self.model`` directly after converting the weights
    to ``ACL_FORMAT_FRACTAL_ND``. Other steps convert the weights to
    ``ACL_FORMAT_FRACTAL_NZ`` and call the torchair-compiled model selected
    for ``padded_num_tokens_across_dp``. Both paths receive the KV caches
    and the attention metadata as extra keyword arguments.
    """
    kv_caches = self.kv_caches

    if with_prefill:
        assert self.model is not None
        maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
        forward = self.model
    else:
        maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
        forward = self._get_torchair_lazy_compiled_model(
            padded_num_tokens_across_dp)

    return forward(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
if is_310p():
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
def init_torchair_graph_batch_sizes(self):
    """Append a doubling series of batch sizes to the configured list.

    The series starts at ``max(4, tp_size)`` and doubles while it does not
    exceed ``self.max_num_reqs``. Entries are appended in place to
    ``self.torchair_graph_batch_sizes``.
    """
    # NOTE: with all2all | mc2 the `num_tokens` dimension is sliced into
    # `tp_size` blocks, so the series never starts below `tp_size`.
    candidate = max(4, get_tensor_model_parallel_world_size())
    while candidate <= self.max_num_reqs:
        self.torchair_graph_batch_sizes.append(candidate)
        candidate *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
    """Return the first configured graph batch size that fits ``batch_size``.

    ``batch_size`` is treated as a number of requests and the configured
    sizes are scanned in list order.

    Raises:
        ValueError: if no configured size is at least ``batch_size``.
    """
    fitting_sizes = (size for size in self.torchair_graph_batch_sizes
                     if batch_size <= size)
    padded_batch_size = next(fitting_sizes, None)
    if padded_batch_size is None:
        raise ValueError(
            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
        )
    return padded_batch_size
def check_torchair_graph_batch_sizes(self):
    """Normalise ``self.torchair_graph_batch_sizes`` for graph capture.

    Steps, in order:

    1. Sort the configured sizes, drop those above ``self.max_num_reqs`` and
       make ``self.max_num_reqs`` the last entry. An empty list, or one that
       becomes empty after dropping, falls back to
       ``[1, self.max_num_reqs]``.
    2. Multiply every size by ``self.decode_token_per_req``.
    3. When expert parallel is enabled, round each size up to a multiple of
       the tensor-parallel size, drop duplicates and skip sizes above
       ``scheduler_config.max_num_batched_tokens``. A warning is logged for
       a skipped size only when ``decode_token_per_req > 1``.

    The result is stored back on ``self.torchair_graph_batch_sizes``.
    """
    # First pad according to the number of requests.
    if not self.torchair_graph_batch_sizes:
        self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
    else:
        self.torchair_graph_batch_sizes = sorted(
            self.torchair_graph_batch_sizes)
        while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
            self.torchair_graph_batch_sizes.pop()
            if not self.torchair_graph_batch_sizes:
                logger.warning(
                    "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
                )
                self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
        if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
            self.torchair_graph_batch_sizes.append(self.max_num_reqs)

    # padded max number tokens = max_num_req * decode_token_per_req
    self.torchair_graph_batch_sizes = [
        graph_batch_size * self.decode_token_per_req
        for graph_batch_size in self.torchair_graph_batch_sizes
    ]

    # NOTE: with enable_expert_parallel every graph batch size must be
    # divisible by `tp_size`.
    tp_size = self.parallel_config.tensor_parallel_size
    if self.parallel_config.enable_expert_parallel:
        max_batched_tokens = self.scheduler_config.max_num_batched_tokens
        new_graph_batch_sizes = []
        for graph_batch_size in self.torchair_graph_batch_sizes:
            # Round up to the next multiple of tp_size.
            cur_graph_batch_size = (graph_batch_size + tp_size -
                                    1) // tp_size * tp_size
            if cur_graph_batch_size not in new_graph_batch_sizes and \
                cur_graph_batch_size <= max_batched_tokens:
                new_graph_batch_sizes.append(cur_graph_batch_size)
            elif cur_graph_batch_size > max_batched_tokens \
                and self.decode_token_per_req > 1:
                # One format string with lazy %-args: passing a second
                # message string positionally would be treated as a format
                # argument and break the rendering of the log record.
                logger.warning(
                    "torchair_graph_batch_sizes %s is bigger than "
                    "max_num_batched_tokens %s, will skip this batch size.",
                    cur_graph_batch_size, max_batched_tokens)
        self.torchair_graph_batch_sizes = new_graph_batch_sizes
def _build_drafter_prepare_inputs_torchair_param(self):
return True

View File

@@ -22,7 +22,6 @@ import gc
import math
import os
import time
import types
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
@@ -39,7 +38,6 @@ from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -108,7 +106,6 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
import torch_npu
import vllm.envs as envs_vllm
import vllm_ascend.envs as envs_ascend
@@ -341,11 +338,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.use_aclgraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.level == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
self.use_aclgraph = self._use_aclgraph()
self.aclgraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
@@ -357,31 +350,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.check_torchair_graph_batch_sizes()
# graph_block_tables shape: [num_request, ceil(max_model_len / block_size)]
self.graph_block_tables = np.zeros(
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
self.check_batch_sizes_consistency()
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
@@ -400,27 +368,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.moe_comm_method = AllGatherCommImpl
def check_batch_sizes_consistency(self) -> None:
    """Verify that the graph batch sizes agree across the DP group.

    The local sizes are all-reduced over the DP CPU group and compared with
    ``local * self.dp_size``. Nothing is checked when torch.distributed has
    not been initialised.

    Raises:
        AssertionError: if the reduced tensor differs from the expected one;
            the message lists the mismatching indices.
    """
    if not dist.is_initialized():
        return

    local_sizes = torch.tensor(self.torchair_graph_batch_sizes,
                               device="cpu",
                               dtype=torch.int32)
    # Reduce into a copy so the local values stay available for reporting.
    reduced_sizes = local_sizes.clone()
    dist.all_reduce(reduced_sizes, group=get_dp_group().cpu_group)

    expected = local_sizes * self.dp_size
    if torch.equal(reduced_sizes, expected):
        return

    mismatch = reduced_sizes != expected
    diff_idxs = mismatch.nonzero(as_tuple=False).flatten().tolist()
    raise AssertionError(
        f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
        f"Local (rank {self.dp_rank}): {local_sizes.tolist()}\n"
        f"Sum over ranks: {reduced_sizes.tolist()}\n"
        f"Expected if all equal: {[v * self.dp_size for v in local_sizes.tolist()]}"
    )
def _use_aclgraph(self) -> bool:
    """Decide whether ACL graph mode is used.

    True only when the cudagraph mode is not ``NONE``, the compilation level
    is ``PIECEWISE`` and ``enforce_eager`` is off.
    """
    compilation_config = self.compilation_config
    return (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and compilation_config.level == CompilationLevel.PIECEWISE
            and not self.model_config.enforce_eager)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove finished requests from the cached states.
@@ -1047,14 +996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _process_reqs(
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
Optional[set[str]], Optional[set[str]]]:
AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor]]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -1103,9 +1053,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
logits_indices = cu_num_tokens - 1
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
positions_np = self.positions_np[:total_num_scheduled_tokens]
@@ -1118,7 +1065,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
@@ -1127,7 +1073,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
positions = self.positions[:num_input_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
self.seq_lens_np[:num_reqs] = (
@@ -1145,34 +1091,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
ascend_config = get_ascend_config()
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
# SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.use_eagle:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.SpecDecoding
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions,
position=self.positions[:num_input_tokens],
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
@@ -1191,8 +1116,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
is_only_prefill = bool(np.all(num_valid_tokens != 1))
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
@@ -1202,10 +1125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
total_num_scheduled_tokens, with_prefill, enable_dbo)
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
if self.torchair_graph_enabled and not with_prefill:
self.graph_pad_size = padded_num_tokens_across_dp
else:
self.graph_pad_size = -1
self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1221,7 +1141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo,
is_only_prefill=is_only_prefill,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
decode_token_per_req=self.decode_token_per_req,
@@ -1248,10 +1168,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
@@ -1273,12 +1190,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the ACL graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
if self.torchair_graph_enabled and not with_prefill:
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
positions = self.positions[:num_input_tokens]
input_ids, positions = self._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1293,8 +1208,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
moe_comm_method = self.moe_comm_method
# NOTE: Currently this padding logic is really messy,
# MC2 may not be available in eager mode
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
@@ -1303,52 +1216,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
num_input_tokens = padded_num_tokens_across_dp
# Run forward pass
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(self.device, self.dtype,
self.model_config.hf_config),
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
assert self.model is not None
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1358,6 +1225,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True)
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
@@ -1372,13 +1241,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
return (attn_metadata, positions, num_scheduled_tokens,
num_input_tokens, num_tokens_across_dp,
padded_num_tokens_across_dp, logits_indices,
spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors)
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens, finished_sending, finished_recving)
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                                         padded_num_tokens_across_dp,
                                         input_ids, positions,
                                         intermediate_tensors,
                                         inputs_embeds):
    """Run ``self.model`` eagerly and return its hidden states.

    The weights are converted to ``ACL_FORMAT_FRACTAL_ND`` first.
    ``attn_metadata``, ``with_prefill`` and ``padded_num_tokens_across_dp``
    are accepted but not read by this implementation.
    """
    assert self.model is not None
    maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)

    return self.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                      num_valid_tokens):
    """Classify the current batch into an ``AscendAttentionState``.

    The checks run in this order and the first match wins:

    1. Sequence lengths equal the scheduled token counts: PrefillNoCache.
    2. Every request schedules exactly one token: DecodeOnly, or
       SpecDecoding when the speculative method is 'deepseek_mtp'.
    3. Every request has exactly one valid token: ChunkedPrefill with
       eagle, otherwise SpecDecoding.
    4. The ascend scheduler is disabled or chunked prefill is enabled:
       ChunkedPrefill.
    5. Anything else: PrefillCacheHit.
    """
    ascend_config = get_ascend_config()

    if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
        return AscendAttentionState.PrefillNoCache

    # Treated as the decode stage: a prefill where only one token missed
    # the cache lands here as well.
    if np.all(num_scheduled_tokens == 1):
        if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
            # SpecDecoding supports seq_len=1 and seq_len=2; the
            # prefill/decode disaggregation scenario needs seq_len=1.
            return AscendAttentionState.SpecDecoding
        return AscendAttentionState.DecodeOnly

    # Speculative decoding.
    if np.all(num_valid_tokens == 1):
        if self.use_eagle:
            return AscendAttentionState.ChunkedPrefill
        return AscendAttentionState.SpecDecoding

    # splitfuse
    if not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
        return AscendAttentionState.ChunkedPrefill

    return AscendAttentionState.PrefillCacheHit
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
self.graph_pad_size = -1
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
return input_ids, positions
def _get_cumsum_and_arange(
self,
@@ -1623,8 +1540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]:
with ProfileExecuteDuration().capture_async(
"prepare input and forward"):
with ProfileExecuteDuration().capture_async("prepare input"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
@@ -1634,11 +1550,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens_np, finished_sending,
finished_recving) = (self._process_reqs(scheduler_output,
intermediate_tensors))
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp,
padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=self.moe_comm_method(
self.device, self.dtype, self.model_config.hf_config),
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
attn_metadata, self.with_prefill,
padded_num_tokens_across_dp, input_ids, positions,
intermediate_tensors, inputs_embeds)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
kv_connector_output = None
if finished_sending is not None or finished_recving is not None:
kv_connector_output = KVConnectorOutput(
@@ -1667,10 +1613,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np,
finished_sending, finished_recving,
kv_connector_output)
return self._pool(
hidden_states,
scheduler_output.total_num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
@@ -1746,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
hidden_states[:scheduler_output.total_num_scheduled_tokens],
scheduler_output,
)
@@ -1796,7 +1743,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
scheduler_output.total_num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
@@ -2191,72 +2138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30))
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
if is_310p():
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
def _convert_torch_format(self, tensor):
    """Return ``tensor`` cast to ``ACL_FORMAT`` via ``torch_npu.npu_format_cast``."""
    return torch_npu.npu_format_cast(tensor, ACL_FORMAT)
@@ -2707,7 +2588,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self.torchair_graph_enabled,
is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(),
)
draft_token_ids = self.drafter.propose(
@@ -2818,72 +2699,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return prompt_logprobs_dict
def init_torchair_graph_batch_sizes(self):
    """Append a doubling series of batch sizes to the configured list.

    The series starts at ``max(4, tp_size)`` and doubles while it does not
    exceed ``self.max_num_reqs``. Entries are appended in place to
    ``self.torchair_graph_batch_sizes``.
    """
    # NOTE: with all2all | mc2 the `num_tokens` dimension is sliced into
    # `tp_size` blocks, so the series never starts below `tp_size`.
    candidate = max(4, get_tensor_model_parallel_world_size())
    while candidate <= self.max_num_reqs:
        self.torchair_graph_batch_sizes.append(candidate)
        candidate *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
    """Return the first configured graph batch size that fits ``batch_size``.

    ``batch_size`` is treated as a number of requests and the configured
    sizes are scanned in list order.

    Raises:
        ValueError: if no configured size is at least ``batch_size``.
    """
    fitting_sizes = (size for size in self.torchair_graph_batch_sizes
                     if batch_size <= size)
    padded_batch_size = next(fitting_sizes, None)
    if padded_batch_size is None:
        raise ValueError(
            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
        )
    return padded_batch_size
def check_torchair_graph_batch_sizes(self):
    """Validate and normalize ``self.torchair_graph_batch_sizes`` in place.

    Steps, in order:

    1. Treat the configured values as request counts: sort them, drop any
       above ``self.max_num_reqs``, fall back to ``[1, max_num_reqs]`` when
       nothing is left, and make sure ``max_num_reqs`` is the last entry.
    2. Convert request counts to token counts by multiplying with
       ``self.decode_token_per_req``.
    3. When expert parallelism is enabled, round every size up to a multiple
       of the tensor-parallel size, dropping duplicates and sizes above
       ``scheduler_config.max_num_batched_tokens``.
    """
    # First pad according to the number of requests.
    if not self.torchair_graph_batch_sizes:
        self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
    else:
        self.torchair_graph_batch_sizes = sorted(
            self.torchair_graph_batch_sizes)
        while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
            self.torchair_graph_batch_sizes.pop()
            if not self.torchair_graph_batch_sizes:
                # Every configured size was too large; the reset value ends
                # the loop because its last entry equals max_num_reqs.
                logger.warning(
                    "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
                )
                self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
        if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
            self.torchair_graph_batch_sizes.append(self.max_num_reqs)

    # Padded max number of tokens = max_num_req * decode_token_per_req.
    self.torchair_graph_batch_sizes = [
        graph_batch_size * self.decode_token_per_req
        for graph_batch_size in self.torchair_graph_batch_sizes
    ]

    # NOTE: when enable_expert_parallel, every `graph_batch_size` must be
    # divisible by `tp_size`, so round each one up to the next multiple.
    tp_size = self.parallel_config.tensor_parallel_size
    if self.parallel_config.enable_expert_parallel:
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        new_graph_batch_sizes = []
        for graph_batch_size in self.torchair_graph_batch_sizes:
            cur_graph_batch_size = (graph_batch_size + tp_size -
                                    1) // tp_size * tp_size
            if cur_graph_batch_size not in new_graph_batch_sizes and \
                    cur_graph_batch_size <= max_num_batched_tokens:
                new_graph_batch_sizes.append(cur_graph_batch_size)
            elif cur_graph_batch_size > max_num_batched_tokens \
                    and self.decode_token_per_req > 1:
                # Single message with lazy %-args: passing a second string
                # positionally would make it a format argument for a
                # placeholder-free message and break the log record.
                logger.warning(
                    "torchair_graph_batch_sizes %s is bigger than "
                    "max_num_batched_tokens %s, will skip this batch size.",
                    cur_graph_batch_size, max_num_batched_tokens)
        self.torchair_graph_batch_sizes = new_graph_batch_sizes
def get_supported_pooling_tasks(self):
    """Return the pooling tasks reported by the model's pooler as a list.

    An empty list is returned when ``is_pooling_model`` reports that the
    loaded model is not a pooling model.
    """
    loaded_model = self.get_model()
    if is_pooling_model(loaded_model):
        return list(loaded_model.pooler.get_supported_tasks())
    return []
def _build_drafter_prepare_inputs_torchair_param(self):
return False

View File

@@ -48,6 +48,8 @@ class MtpProposer:
device=self.runner.device)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
@staticmethod
def prepare_inputs(
@@ -136,7 +138,7 @@ class MtpProposer:
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None and self.runner.torchair_graph_enabled:
if token_indices is not None and self.torchair_graph_enabled:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids
@@ -154,7 +156,7 @@ class MtpProposer:
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
is_running_torchair = self.runner.torchair_graph_enabled and \
is_running_torchair = self.torchair_graph_enabled and \
not self.runner.with_prefill
if is_running_torchair:
@@ -193,7 +195,7 @@ class MtpProposer:
attn_metadata.prefill.input_positions = target_positions
attn_metadata.prefill.seq_lens = seq_lens
if not self.runner.torchair_graph_enabled:
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
@@ -216,7 +218,7 @@ class MtpProposer:
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.runner.torchair_graph_enabled:
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
@@ -280,12 +282,12 @@ class MtpProposer:
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None) -> None:
if not self.runner.torchair_graph_enabled:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
is_running_torchair = self.runner.torchair_graph_enabled and \
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
if is_running_torchair: