Refactor model runner v1 (#2461)
Refactor model runner v1
### What this PR does / why we need it?
1. Separate the model-execution logic from the input-preparation logic
2. Decouple the torchair logic from model runner v1
- vLLM version: v0.10.0
- vLLM main: 68fcd3fa73
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -48,6 +48,8 @@ class MtpProposer:
|
||||
device=self.runner.device)
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.torchair_graph_enabled = get_ascend_config(
|
||||
).torchair_graph_config.enabled
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs(
|
||||
@@ -136,7 +138,7 @@ class MtpProposer:
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
if token_indices is not None and self.runner.torchair_graph_enabled:
|
||||
if token_indices is not None and self.torchair_graph_enabled:
|
||||
last_token_indices = token_indices
|
||||
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
@@ -154,7 +156,7 @@ class MtpProposer:
|
||||
# input_batch=self.runner.input_batch,
|
||||
# scheduler_output=self.runner.scheduler_output,
|
||||
# )
|
||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not self.runner.with_prefill
|
||||
|
||||
if is_running_torchair:
|
||||
@@ -193,7 +195,7 @@ class MtpProposer:
|
||||
attn_metadata.prefill.input_positions = target_positions
|
||||
attn_metadata.prefill.seq_lens = seq_lens
|
||||
|
||||
if not self.runner.torchair_graph_enabled:
|
||||
if not self.torchair_graph_enabled:
|
||||
# torch mode need to update num_tokens_across_dp
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -216,7 +218,7 @@ class MtpProposer:
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
if self.runner.torchair_graph_enabled:
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
||||
if is_running_torchair:
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
@@ -280,12 +282,12 @@ class MtpProposer:
|
||||
skip_attn: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None) -> None:
|
||||
if not self.runner.torchair_graph_enabled:
|
||||
if not self.torchair_graph_enabled:
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
||||
num_tokens, with_prefill, False)
|
||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not with_prefill
|
||||
|
||||
if is_running_torchair:
|
||||
|
||||
Reference in New Issue
Block a user