[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No change at user-facing.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -7,13 +7,13 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -142,9 +142,9 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
@@ -239,9 +239,9 @@ class EagleProposer:
attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device)
# Run the model.
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
@@ -344,8 +344,9 @@ class EagleProposer:
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
with set_ascend_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],

View File

@@ -34,7 +34,6 @@ import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.forward_context import get_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
# kv role
self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None:
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self,
maybe_padded_num_tokens: int,
num_tokens: int,
with_prefill: bool,
enable_dbo: bool = False,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
if self.dp_size == 1:
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp = [0] * self.dp_size * 2
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
forward_metadata = torch.tensor(num_tokens_across_dp +
[with_prefill, not enable_dbo],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
with_prefill = bool(forward_metadata[-2])
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
if with_prefill:
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
2]
maybe_padded_num_tokens = num_tokens
else:
num_tokens_across_dp = forward_metadata[:self.dp_size]
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
# `max_tokens_across_dp`, in other situation it is not necessary.
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
forward_metadata[-1])
def get_eagle_atten_dict(
self,
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
maybe_padded_num_tokens = total_num_scheduled_tokens
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_input_tokens]
if self.torchair_graph_enabled and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
})
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_forward_context(None, self.vllm_config):
with set_ascend_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfer(scheduler_output))
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
) -> torch.Tensor:
maybe_padded_num_tokens = num_tokens
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, num_tokens, with_prefill, False)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
num_actual_tokens=0,
):
model_kwargs = {}
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
if is_compile:
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
**model_kwargs,
)
else:
maybe_converting_weight_acl_format(self.model,
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens)
return hidden_states
@contextmanager
def set_in_profile_run(self):
self.in_profile_run = True
try:
yield
finally:
self.in_profile_run = False
def profile_run(self) -> None:
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
with self.set_in_profile_run():
hidden_states = self._dummy_run(self.max_num_tokens,
with_prefill=True)
output = None
if get_pp_group().is_last_rank:
if self.is_pooling_model:
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.

View File

@@ -2,12 +2,12 @@ import torch
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
@@ -117,7 +117,7 @@ class MtpProposer:
query_start_loc=cu_num_tokens,
)
with set_forward_context(attn_metadata, self.vllm_config):
with set_ascend_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,

View File

@@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
vllm_version_is)
from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
try_register_lib, vllm_version_is)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if not vllm_version_is("0.10.0"):
@@ -134,6 +135,7 @@ class NPUWorker(WorkerBase):
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
init_ascend_soc_version()
# Initialize the distributed environment.
self._init_worker_distributed_environment()
# Set random seed.
@@ -272,20 +274,8 @@ class NPUWorker(WorkerBase):
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def _get_max_num_tokens_and_with_prefill(self):
max_num_tokens = 1
with_prefill = False
if self.model_runner.dp_size > 1:
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
return max_num_tokens, with_prefill
def execute_dummy_batch(self) -> None:
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
)
self.model_runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)
self.model_runner._dummy_run(1)
def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
@@ -295,6 +285,7 @@ class NPUWorker(WorkerBase):
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config)
def _init_profiler(self):