refact runner model v1 (#2461)

refact model runner v1

### What this PR does / why we need it?
1. Separate the execute model logic from the prepare input logic
2. Disassemble the torchchair in model runner v1

- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-08-21 08:54:57 +08:00
committed by GitHub
parent 1de16ead8e
commit 0dca4c6dbd
3 changed files with 368 additions and 307 deletions

View File

@@ -48,6 +48,8 @@ class MtpProposer:
device=self.runner.device)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = get_ascend_config(
).torchair_graph_config.enabled
@staticmethod
def prepare_inputs(
@@ -136,7 +138,7 @@ class MtpProposer:
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
if token_indices is not None and self.runner.torchair_graph_enabled:
if token_indices is not None and self.torchair_graph_enabled:
last_token_indices = token_indices
self.input_ids[last_token_indices] = next_token_ids
@@ -154,7 +156,7 @@ class MtpProposer:
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
is_running_torchair = self.runner.torchair_graph_enabled and \
is_running_torchair = self.torchair_graph_enabled and \
not self.runner.with_prefill
if is_running_torchair:
@@ -193,7 +195,7 @@ class MtpProposer:
attn_metadata.prefill.input_positions = target_positions
attn_metadata.prefill.seq_lens = seq_lens
if not self.runner.torchair_graph_enabled:
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
@@ -216,7 +218,7 @@ class MtpProposer:
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.runner.torchair_graph_enabled:
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
@@ -280,12 +282,12 @@ class MtpProposer:
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None) -> None:
if not self.runner.torchair_graph_enabled:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
is_running_torchair = self.runner.torchair_graph_enabled and \
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
if is_running_torchair: