Refactor model runner v1 (#2461)

Refactor model runner v1

### What this PR does / why we need it?
1. Separate the execute model logic from the prepare input logic
2. Extract the torchair logic out of model runner v1

- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-08-21 08:54:57 +08:00
committed by GitHub
parent 1de16ead8e
commit 0dca4c6dbd
3 changed files with 368 additions and 307 deletions

View File

@@ -17,21 +17,29 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# #
import types
from typing import Optional from typing import Optional
import torch import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
check_torchair_cache_exist, check_torchair_cache_exist,
register_torchair_model, register_torchair_model,
write_kv_cache_bytes_to_file) write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
maybe_converting_weight_acl_format) is_310p, maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device): def __init__(self, vllm_config: VllmConfig, device: torch.device):
super().__init__(vllm_config, device) super().__init__(vllm_config, device)
ascend_config = get_ascend_config()
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.check_torchair_graph_batch_sizes()
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
self._check_batch_sizes_consistency()
register_torchair_model() register_torchair_model()
def _get_forward_metadata_across_dp_and_pad( def _get_forward_metadata_across_dp_and_pad(
@@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner):
if self.new_kv_cache_bytes > 0: if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(), write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes) self.new_kv_cache_bytes)
def _use_aclgraph(self) -> bool:
return False
def _check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
local = torch.tensor(self.torchair_graph_batch_sizes,
device="cpu",
dtype=torch.int32)
gathered_graph_batch_size = local.clone()
dist.all_reduce(gathered_graph_batch_size,
group=get_dp_group().cpu_group)
expected = local * self.dp_size
if not torch.equal(gathered_graph_batch_size, expected):
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
as_tuple=False).flatten().tolist()
raise AssertionError(
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
)
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
if not with_prefill:
self.graph_pad_size = graph_pad_size
else:
super()._update_graph_pad_size(with_prefill, graph_pad_size)
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
"""Override from NPUModelRunner to update input_ids and positions"""
input_ids, positions = super()._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)
if not with_prefill:
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
return input_ids, positions
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
padded_num_tokens_across_dp,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
model_kwargs = {
"kv_caches": self.kv_caches,
"attn_metadata": attn_metadata
}
if not with_prefill:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
assert self.model is not None
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return hidden_states
    def _get_torchair_lazy_compiled_model(self, batch_size: int):
        """Return a torchair-compiled version of ``self.model`` for ``batch_size``.

        Compilation is lazy and memoized: with ``use_cached_npu_graph`` a
        per-batch-size entry is kept in ``self.torchair_compiled_models``;
        otherwise a single dynamic-shape model in
        ``self.torchair_compiled_model``.

        Raises:
            ValueError: if ``batch_size`` is negative or exceeds the largest
                configured graph batch size.
        """
        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
            raise ValueError(
                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
            )
        # Return a previously compiled model before paying compilation cost.
        compiled_model = self.torchair_compiled_models.get(
            batch_size
        ) if self.use_cached_npu_graph else self.torchair_compiled_model
        if compiled_model:
            return compiled_model
        import torchair  # type: ignore
        from torchair import patch_for_hcom  # type: ignore
        patch_for_hcom()
        if is_310p():
            # on 300I Duo platform, we need to patch broadcast. however, this patch will be
            # overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
            from vllm_ascend.patch.platform.patch_common.patch_distributed import \
                communication_adaptation_310p
            communication_adaptation_310p()
        config = torchair.CompilerConfig()
        config.experimental_config.frozen_parameter = True
        # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
        # disable it on 300I Duo platform now.
        config.experimental_config.tiling_schedule_optimize = not is_310p()
        config.experimental_config.enable_view_optimize = \
            get_ascend_config().torchair_graph_config.enable_view_optimize
        torch.npu.set_compile_mode(jit_compile=False)
        if not self.use_cached_npu_graph:
            # Single dynamic-shape compile through the torchair dynamo backend.
            npu_backend = torchair.get_npu_backend(compiler_config=config)
            self.torchair_compiled_model = torch.compile(
                self.model,
                dynamic=True,
                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=npu_backend)
            return self.torchair_compiled_model
        else:
            # Generate a new forward proxy code object to prevent the invalidation of
            # compilation cache caused by dynamo retracing
            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
            forward_fn = self.model.forward
            code = forward_fn.__code__
            # Mark code object with a new proxy name
            modified_code = code.replace(co_name=forward_proxy_name, )
            modified_func = types.FunctionType(modified_code,
                                               forward_fn.__globals__,
                                               name=forward_proxy_name,
                                               argdefs=forward_fn.__defaults__)
            # Bind the renamed function as a method so cache_compile keys on a
            # stable, batch-size-specific name.
            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
                self.model, nn.Module)
            self.torchair_compiled_models[
                batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
                    dynamic=True,
                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                    config=config,
                    ge_cache=False)
            return self.torchair_compiled_models[batch_size]
def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size()
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
start_graph_batch_size = max(start_graph_batch_size, tp_size)
while (start_graph_batch_size <= self.max_num_reqs):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size:
# we treat batch_size as num of requests
return padded_batch_size
raise ValueError(
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
)
def check_torchair_graph_batch_sizes(self):
# return graph_batch_sizes according to the max number of tokens
# first pad according to the number of requests
if len(self.torchair_graph_batch_sizes) == 0:
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
else:
self.torchair_graph_batch_sizes = sorted(
self.torchair_graph_batch_sizes)
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
self.torchair_graph_batch_sizes.pop()
if len(self.torchair_graph_batch_sizes) == 0:
logger.warning(
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
)
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
# padded max number tokens = max_num_req * decode_token_per_req
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
tp_size = self.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel:
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
def _build_drafter_prepare_inputs_torchair_param(self):
return True

View File

@@ -22,7 +22,6 @@ import gc
import math import math
import os import os
import time import time
import types
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
@@ -39,7 +38,6 @@ from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -108,7 +106,6 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
import torch_npu import torch_npu
import vllm.envs as envs_vllm
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
@@ -341,11 +338,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pin_memory=True) pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
self.use_aclgraph = ( self.use_aclgraph = self._use_aclgraph()
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.level == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
self.aclgraph_batch_sizes = list( self.aclgraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes)) reversed(self.compilation_config.cudagraph_capture_sizes))
@@ -357,31 +350,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._draft_token_ids: Optional[Union[list[list[int]], self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None torch.Tensor]] = None
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.check_torchair_graph_batch_sizes()
# graph_block_tables shape: [num_request, cell(max_model_len / block_size)]
self.graph_block_tables = np.zeros(
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
self.check_batch_sizes_consistency()
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False self.in_profile_run = False
@@ -400,27 +368,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.moe_comm_method = AllGatherCommImpl self.moe_comm_method = AllGatherCommImpl
def check_batch_sizes_consistency(self) -> None: def _use_aclgraph(self) -> bool:
if not dist.is_initialized(): return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
return
local = torch.tensor(self.torchair_graph_batch_sizes,
device="cpu",
dtype=torch.int32)
gathered_graph_batch_size = local.clone()
dist.all_reduce(gathered_graph_batch_size,
group=get_dp_group().cpu_group)
expected = local * self.dp_size
if not torch.equal(gathered_graph_batch_size, expected):
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
as_tuple=False).flatten().tolist()
raise AssertionError(
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove finished requests from the cached states. # Remove finished requests from the cached states.
@@ -1047,14 +996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32) dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _process_reqs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata, ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata, AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray, torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
Optional[set[str]], Optional[set[str]]]: Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor]]:
# Check input valid # Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
@@ -1103,9 +1053,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens = np.cumsum(num_scheduled_tokens) cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens) num_scheduled_tokens)
logits_indices = cu_num_tokens - 1
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
positions_np = self.positions_np[:total_num_scheduled_tokens] positions_np = self.positions_np[:total_num_scheduled_tokens]
@@ -1118,7 +1065,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope: if self.uses_mrope:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens], self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
@@ -1127,7 +1073,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:total_num_scheduled_tokens].copy_( self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
positions = self.positions[:num_input_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens) self.query_lens = torch.from_numpy(num_scheduled_tokens)
self.seq_lens_np[:num_reqs] = ( self.seq_lens_np[:num_reqs] = (
@@ -1145,34 +1091,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_offsets, block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens]) out=self.slot_mapping_np[:total_num_scheduled_tokens])
ascend_config = get_ascend_config() attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
use_spec_decode = len( num_valid_tokens)
scheduler_output.scheduled_spec_decode_tokens) > 0
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
# SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.use_eagle:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.SpecDecoding
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
self.attn_mask = self._make_attention_mask( self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens, seq_lens=seq_lens,
query_lens=num_scheduled_tokens, query_lens=num_scheduled_tokens,
position=positions, position=self.positions[:num_input_tokens],
attn_state=attn_state) attn_state=attn_state)
self.attn_state = attn_state # type: ignore self.attn_state = attn_state # type: ignore
@@ -1191,8 +1116,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
] ]
is_only_prefill = bool(np.all(num_valid_tokens != 1))
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state, attn_state,
total_num_scheduled_tokens) total_num_scheduled_tokens)
@@ -1202,10 +1125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
total_num_scheduled_tokens, with_prefill, enable_dbo) total_num_scheduled_tokens, with_prefill, enable_dbo)
self.with_prefill = with_prefill self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp self.num_tokens_across_dp = num_tokens_across_dp
if self.torchair_graph_enabled and not with_prefill: self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
self.graph_pad_size = padded_num_tokens_across_dp
else:
self.graph_pad_size = -1
common_attn_metadata = AscendCommonAttentionMetadata( common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1221,7 +1141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
spec_attn_mask=self.spec_attn_mask, spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state, attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo, enable_dbo_across_dp=enable_dbo,
is_only_prefill=is_only_prefill, is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size, graph_pad_size=self.graph_pad_size,
decode_token_per_req=self.decode_token_per_req, decode_token_per_req=self.decode_token_per_req,
@@ -1248,10 +1168,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
@@ -1273,12 +1190,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the ACL graph. # then the embedding layer is not included in the ACL graph.
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None inputs_embeds = None
if self.uses_mrope: positions = self.positions[:num_input_tokens]
positions = self.mrope_positions[:, :num_input_tokens] input_ids, positions = self._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
if self.torchair_graph_enabled and not with_prefill: padded_num_tokens_across_dp)
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
@@ -1293,8 +1208,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items() for k, v in self.intermediate_tensors.items()
}) })
moe_comm_method = self.moe_comm_method
# NOTE: Currently this padding logic is really messy, # NOTE: Currently this padding logic is really messy,
# MC2 may not be available in eager mode # MC2 may not be available in eager mode
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
@@ -1303,52 +1216,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else: else:
num_input_tokens = padded_num_tokens_across_dp num_input_tokens = padded_num_tokens_across_dp
# Run forward pass
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(self.device, self.dtype,
self.model_config.hf_config),
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
assert self.model is not None
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
use_spec_decode = len( use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0 scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode: if not use_spec_decode:
@@ -1358,6 +1225,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# We will ignore the sampled tokens from the partial requests. # We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
spec_decode_metadata = None spec_decode_metadata = None
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True)
else: else:
# Get the number of draft tokens for each request. # Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all # Iterate over the dictionary rather than all requests since not all
@@ -1372,13 +1241,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens, cu_num_tokens) num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices logits_indices = spec_decode_metadata.logits_indices
aux_hidden_states = None return (attn_metadata, positions, num_scheduled_tokens,
if self.use_aux_hidden_state_outputs: num_input_tokens, num_tokens_across_dp,
hidden_states, aux_hidden_states = hidden_states padded_num_tokens_across_dp, logits_indices,
spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors)
return (attn_metadata, hidden_states, spec_decode_metadata, positions, def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
total_num_scheduled_tokens, logits_indices, aux_hidden_states, padded_num_tokens_across_dp,
num_scheduled_tokens, finished_sending, finished_recving) input_ids, positions,
intermediate_tensors,
inputs_embeds):
assert self.model is not None
maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
ascend_config = get_ascend_config()
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
# SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.use_eagle:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.SpecDecoding
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
return attn_state
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
self.graph_pad_size = -1
def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
return input_ids, positions
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
@@ -1623,8 +1540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]: ) -> Union[ModelRunnerOutput, torch.Tensor]:
with ProfileExecuteDuration().capture_async( with ProfileExecuteDuration().capture_async("prepare input"):
"prepare input and forward"):
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group(): if not has_kv_transfer_group():
@@ -1634,11 +1550,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOuptut if there's no work to do. # Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output) return self.kv_connector_no_forward(scheduler_output)
(attn_metadata, hidden_states, spec_decode_metadata, positions, (attn_metadata, positions, num_scheduled_tokens_np,
num_scheduled_tokens, logits_indices, aux_hidden_states, num_input_tokens, num_tokens_across_dp,
num_scheduled_tokens_np, finished_sending, padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
finished_recving) = (self._process_reqs(scheduler_output, input_ids, inputs_embeds,
intermediate_tensors)) intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=self.moe_comm_method(
self.device, self.dtype, self.model_config.hf_config),
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
attn_metadata, self.with_prefill,
padded_num_tokens_across_dp, input_ids, positions,
intermediate_tensors, inputs_embeds)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
kv_connector_output = None kv_connector_output = None
if finished_sending is not None or finished_recving is not None: if finished_sending is not None or finished_recving is not None:
kv_connector_output = KVConnectorOutput( kv_connector_output = KVConnectorOutput(
@@ -1667,10 +1613,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = None logits = None
else: else:
if self.input_batch.pooling_params: if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens, return self._pool(
num_scheduled_tokens_np, hidden_states,
finished_sending, finished_recving, scheduler_output.total_num_scheduled_tokens,
kv_connector_output) num_scheduled_tokens_np, finished_sending,
finished_recving, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output: if broadcast_pp_output:
@@ -1746,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed. # Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict( prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens], hidden_states[:scheduler_output.total_num_scheduled_tokens],
scheduler_output, scheduler_output,
) )
@@ -1796,7 +1743,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output, scheduler_output,
spec_decode_metadata, spec_decode_metadata,
positions, positions,
num_scheduled_tokens, scheduler_output.total_num_scheduled_tokens,
hidden_states, hidden_states,
attn_metadata, attn_metadata,
aux_hidden_states, aux_hidden_states,
@@ -2191,72 +2138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",
m.consumed_memory / float(2**30)) m.consumed_memory / float(2**30))
    def _get_torchair_lazy_compiled_model(self, batch_size: int):
        """Return (building lazily on first use) the torchair-compiled model
        for the given padded graph batch size.

        Two caching strategies exist:
          * ``use_cached_npu_graph`` off: one ``torch.compile`` object shared
            across batch sizes (``self.torchair_compiled_model``).
          * ``use_cached_npu_graph`` on: one cache-compiled entry per batch
            size (``self.torchair_compiled_models[batch_size]``), backed by
            torchair's on-disk compilation cache.

        Raises:
            ValueError: if ``batch_size`` is negative or exceeds the largest
                captured graph batch size.
        """
        if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
            raise ValueError(
                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
            )

        # Fast path: reuse an already-compiled model if one exists.
        compiled_model = self.torchair_compiled_models.get(
            batch_size
        ) if self.use_cached_npu_graph else self.torchair_compiled_model

        if compiled_model:
            return compiled_model

        import torchair  # type: ignore
        from torchair import patch_for_hcom  # type: ignore

        patch_for_hcom()

        if is_310p():
            # on 300I Duo platform, we need to patch broadcast. however, this patch will be
            # overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
            from vllm_ascend.patch.platform.patch_common.patch_distributed import \
                communication_adaptation_310p
            communication_adaptation_310p()

        config = torchair.CompilerConfig()
        config.experimental_config.frozen_parameter = True
        # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
        # disable it on 300I Duo platform now.
        config.experimental_config.tiling_schedule_optimize = not is_310p()
        config.experimental_config.enable_view_optimize = \
            get_ascend_config().torchair_graph_config.enable_view_optimize
        torch.npu.set_compile_mode(jit_compile=False)
        if not self.use_cached_npu_graph:
            npu_backend = torchair.get_npu_backend(compiler_config=config)
            self.torchair_compiled_model = torch.compile(
                self.model,
                dynamic=True,
                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=npu_backend)
            return self.torchair_compiled_model
        else:
            # Generate a new forward proxy code object to prevent the invalidation of
            # compilation cache caused by dynamo retracing
            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
            forward_fn = self.model.forward
            code = forward_fn.__code__
            # Mark code object with a new proxy name
            modified_code = code.replace(co_name=forward_proxy_name, )

            # Rebuild a function around the renamed code object and bind it to
            # the model instance so torchair keys its cache on the proxy name.
            modified_func = types.FunctionType(modified_code,
                                               forward_fn.__globals__,
                                               name=forward_proxy_name,
                                               argdefs=forward_fn.__defaults__)

            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
                self.model, nn.Module)
            self.torchair_compiled_models[
                batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
                    dynamic=True,
                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                    config=config,
                    ge_cache=False)
            return self.torchair_compiled_models[batch_size]
def _convert_torch_format(self, tensor): def _convert_torch_format(self, tensor):
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor return tensor
@@ -2707,7 +2588,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions[:num_scheduled_tokens], positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens], attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self.torchair_graph_enabled, is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(),
) )
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
@@ -2818,72 +2699,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return prompt_logprobs_dict return prompt_logprobs_dict
def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size()
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
start_graph_batch_size = max(start_graph_batch_size, tp_size)
while (start_graph_batch_size <= self.max_num_reqs):
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
start_graph_batch_size *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size:
# we treat batch_size as num of requests
return padded_batch_size
raise ValueError(
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
)
def check_torchair_graph_batch_sizes(self):
# return graph_batch_sizes according to the max number of tokens
# first pad according to the number of requests
if len(self.torchair_graph_batch_sizes) == 0:
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
else:
self.torchair_graph_batch_sizes = sorted(
self.torchair_graph_batch_sizes)
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
self.torchair_graph_batch_sizes.pop()
if len(self.torchair_graph_batch_sizes) == 0:
logger.warning(
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
)
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
# padded max number tokens = max_num_req * decode_token_per_req
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
tp_size = self.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel:
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
def get_supported_pooling_tasks(self): def get_supported_pooling_tasks(self):
model = self.get_model() model = self.get_model()
if not is_pooling_model(model): if not is_pooling_model(model):
return [] return []
return list(model.pooler.get_supported_tasks()) return list(model.pooler.get_supported_tasks())
    def _build_drafter_prepare_inputs_torchair_param(self):
        """Return whether the drafter should prepare inputs for torchair
        graph mode.

        The base (eager) runner always answers False. NOTE(review):
        presumably the torchair runner subclass overrides this to return its
        graph-enabled flag — confirm against the subclass.
        """
        return False

View File

@@ -48,6 +48,8 @@ class MtpProposer:
device=self.runner.device) device=self.runner.device)
self.torchair_compiled_model = None # type: ignore self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
@staticmethod @staticmethod
def prepare_inputs( def prepare_inputs(
@@ -136,7 +138,7 @@ class MtpProposer:
self.input_ids[:num_tokens - 1] = target_token_ids[1:] self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None and self.runner.torchair_graph_enabled: if token_indices is not None and self.torchair_graph_enabled:
last_token_indices = token_indices last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
@@ -154,7 +156,7 @@ class MtpProposer:
# input_batch=self.runner.input_batch, # input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output, # scheduler_output=self.runner.scheduler_output,
# ) # )
is_running_torchair = self.runner.torchair_graph_enabled and \ is_running_torchair = self.torchair_graph_enabled and \
not self.runner.with_prefill not self.runner.with_prefill
if is_running_torchair: if is_running_torchair:
@@ -193,7 +195,7 @@ class MtpProposer:
attn_metadata.prefill.input_positions = target_positions attn_metadata.prefill.input_positions = target_positions
attn_metadata.prefill.seq_lens = seq_lens attn_metadata.prefill.seq_lens = seq_lens
if not self.runner.torchair_graph_enabled: if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp # torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later # TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill, (num_input_tokens, num_tokens_across_dp, with_prefill,
@@ -216,7 +218,7 @@ class MtpProposer:
with ProfileExecuteDuration().capture_async('mtp_forward'): with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {} model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata model_kwargs["attn_metadata"] = attn_metadata
if self.runner.torchair_graph_enabled: if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair: if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model( torchair_compiled_model = self._get_torchair_lazy_compiled_model(
@@ -280,12 +282,12 @@ class MtpProposer:
skip_attn: bool = False, skip_attn: bool = False,
num_reqs: int = 0, num_reqs: int = 0,
num_tokens_across_dp=None) -> None: num_tokens_across_dp=None) -> None:
if not self.runner.torchair_graph_enabled: if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later # TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill, (num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad( _) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False) num_tokens, with_prefill, False)
is_running_torchair = self.runner.torchair_graph_enabled and \ is_running_torchair = self.torchair_graph_enabled and \
not with_prefill not with_prefill
if is_running_torchair: if is_running_torchair: