[3/N][Refactor] torchair model runner refactor (#2207)
There is lot of torchair code in model runner leading the code hard for
maintenance. We'll create new torchair_model_runner to split torchair
related logic. Following the workflow #2203, this is the first PR.
What's this PR do:
create common function `_build_attention_metadata` and
`_generate_dummy_run_hidden_states` for dummy_run
- vLLM version: v0.10.0
- vLLM main:
ebf7605b0d
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -21,7 +21,10 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
||||||
|
maybe_converting_weight_acl_format)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -55,3 +58,58 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
maybe_padded_num_tokens = num_tokens
|
maybe_padded_num_tokens = num_tokens
|
||||||
|
|
||||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||||
|
|
||||||
|
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
|
||||||
|
# NOTE: If torchair graph mode and not with_prefill,
|
||||||
|
# we can't skip_attn, it will cause graph recompile.
|
||||||
|
if not with_prefill:
|
||||||
|
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||||
|
num_reqs=num_reqs, num_actual_tokens=1)
|
||||||
|
else:
|
||||||
|
attn_metadata = super()._build_attention_metadata(
|
||||||
|
with_prefill, num_reqs, skip_attn)
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
|
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||||
|
is_torchair_compile, input_ids,
|
||||||
|
positions, attn_metadata, num_tokens,
|
||||||
|
intermediate_tensors, inputs_embeds):
|
||||||
|
|
||||||
|
if not with_prefill:
|
||||||
|
# Only mark static while compiling
|
||||||
|
if is_torchair_compile:
|
||||||
|
torch._dynamo.mark_static(input_ids)
|
||||||
|
torch._dynamo.mark_static(positions)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||||
|
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||||
|
if hasattr(attn_metadata.decode, "sin"):
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||||
|
if self.speculative_config:
|
||||||
|
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||||
|
for kv in self.kv_caches:
|
||||||
|
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||||
|
torch._dynamo.mark_static(kv[0])
|
||||||
|
torch._dynamo.mark_static(kv[1])
|
||||||
|
|
||||||
|
maybe_converting_weight_acl_format(self.model,
|
||||||
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
|
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||||
|
model_kwargs = {}
|
||||||
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
|
hidden_states = compiled_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=None,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||||
|
with_prefill, is_torchair_compile, input_ids, positions,
|
||||||
|
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|||||||
@@ -1832,6 +1832,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output.finished_req_ids)
|
scheduler_output.finished_req_ids)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
|
||||||
|
if skip_attn:
|
||||||
|
attn_metadata = None
|
||||||
|
else:
|
||||||
|
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||||
|
attn_metadata = None
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
|
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||||
|
is_torchair_compile, input_ids,
|
||||||
|
positions, attn_metadata, num_tokens,
|
||||||
|
intermediate_tensors, inputs_embeds):
|
||||||
|
maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||||
|
hidden_states = self.model(input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
hidden_states, _ = hidden_states
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states
|
||||||
|
if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
|
||||||
|
self.drafter.dummy_run(num_tokens)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _dummy_run(
|
def _dummy_run(
|
||||||
self,
|
self,
|
||||||
@@ -1868,20 +1893,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.is_kv_producer:
|
if self.is_kv_producer:
|
||||||
with_prefill = True
|
with_prefill = True
|
||||||
|
|
||||||
# NOTE: If torchair graph mode and not with_prefill,
|
attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
|
||||||
# we can't skip_attn, it will cause graph recompile.
|
skip_attn)
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
|
||||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
|
||||||
num_reqs=num_reqs, num_actual_tokens=1)
|
|
||||||
elif skip_attn:
|
|
||||||
attn_metadata = None
|
|
||||||
else:
|
|
||||||
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
|
||||||
attn_metadata = None
|
|
||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||||
num_scheduled_tokens):
|
num_scheduled_tokens):
|
||||||
model = self.model
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||||
@@ -1917,61 +1933,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
in_profile_run=self.in_profile_run,
|
in_profile_run=self.in_profile_run,
|
||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
):
|
):
|
||||||
model_kwargs = {}
|
hidden_states = self._generate_dummy_run_hidden_states(
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
with_prefill, is_torchair_compile, input_ids, positions,
|
||||||
# Only mark static while compiling
|
attn_metadata, num_tokens, intermediate_tensors,
|
||||||
if is_torchair_compile:
|
inputs_embeds)
|
||||||
torch._dynamo.mark_static(input_ids)
|
|
||||||
torch._dynamo.mark_static(positions)
|
|
||||||
torch._dynamo.mark_static(
|
|
||||||
attn_metadata.decode.block_table)
|
|
||||||
torch._dynamo.mark_static(
|
|
||||||
attn_metadata.decode.input_positions)
|
|
||||||
torch._dynamo.mark_static(
|
|
||||||
get_forward_context().mc2_mask)
|
|
||||||
if hasattr(attn_metadata.decode, "sin"):
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
|
||||||
if self.speculative_config:
|
|
||||||
torch._dynamo.mark_static(
|
|
||||||
attn_metadata.decode.attn_mask)
|
|
||||||
for kv in self.kv_caches:
|
|
||||||
assert isinstance(
|
|
||||||
kv, tuple), "kv_cache must be a tuple"
|
|
||||||
torch._dynamo.mark_static(kv[0])
|
|
||||||
torch._dynamo.mark_static(kv[1])
|
|
||||||
|
|
||||||
maybe_converting_weight_acl_format(self.model,
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
|
||||||
num_tokens)
|
|
||||||
model_kwargs["kv_caches"] = self.kv_caches
|
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
|
||||||
hidden_states = compiled_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=None,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
maybe_converting_weight_acl_format(self.model,
|
|
||||||
ACL_FORMAT_FRACTAL_ND)
|
|
||||||
|
|
||||||
hidden_states = model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds)
|
|
||||||
if self.use_aux_hidden_state_outputs:
|
|
||||||
hidden_states, _ = hidden_states
|
|
||||||
else:
|
|
||||||
hidden_states = hidden_states
|
|
||||||
if self.use_spec_decode and isinstance(
|
|
||||||
self.drafter, EagleProposer):
|
|
||||||
self.drafter.dummy_run(num_tokens)
|
|
||||||
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
|
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
|
||||||
assert isinstance(self.drafter, MtpProposer)
|
assert isinstance(self.drafter, MtpProposer)
|
||||||
self.drafter.dummy_run(
|
self.drafter.dummy_run(
|
||||||
|
|||||||
Reference in New Issue
Block a user