[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

This PR refactors forward_context and model_runner_v1: it adds the context
that is needed during model inference to forward_context, and it refactors
the dummy_run logic to make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
Refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No user-facing changes.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -34,7 +34,6 @@ import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.forward_context import get_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
# kv role
self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None:
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self,
maybe_padded_num_tokens: int,
num_tokens: int,
with_prefill: bool,
enable_dbo: bool = False,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
if self.dp_size == 1:
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp = [0] * self.dp_size * 2
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
forward_metadata = torch.tensor(num_tokens_across_dp +
[with_prefill, not enable_dbo],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
with_prefill = bool(forward_metadata[-2])
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
if with_prefill:
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
2]
maybe_padded_num_tokens = num_tokens
else:
num_tokens_across_dp = forward_metadata[:self.dp_size]
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
# `max_tokens_across_dp`, in other situation it is not necessary.
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
forward_metadata[-1])
def get_eagle_atten_dict(
self,
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
maybe_padded_num_tokens = total_num_scheduled_tokens
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_input_tokens]
if self.torchair_graph_enabled and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
})
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_forward_context(None, self.vllm_config):
with set_ascend_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfer(scheduler_output))
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
) -> torch.Tensor:
maybe_padded_num_tokens = num_tokens
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, num_tokens, with_prefill, False)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
num_actual_tokens=0,
):
model_kwargs = {}
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
if is_compile:
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
**model_kwargs,
)
else:
maybe_converting_weight_acl_format(self.model,
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens)
return hidden_states
@contextmanager
def set_in_profile_run(self):
    """Mark the runner as being inside a profile run for the `with` body.

    Sets ``self.in_profile_run`` to ``True`` while the context is active.
    ``_dummy_run`` forwards this flag to ``set_ascend_forward_context``.

    On exit (normal or via an exception) the flag is restored to the value
    it had on entry rather than being forced to ``False``, so a nested use
    does not clear the flag while an outer context is still active.
    """
    previous = self.in_profile_run
    self.in_profile_run = True
    try:
        yield
    finally:
        # Restore instead of hard-resetting so nesting stays correct.
        self.in_profile_run = previous
def profile_run(self) -> None:
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
with self.set_in_profile_run():
hidden_states = self._dummy_run(self.max_num_tokens,
with_prefill=True)
output = None
if get_pp_group().is_last_rank:
if self.is_pooling_model:
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.