[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -34,7 +34,6 @@ import torch
|
||||
import torch._dynamo.cache_size
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||
self.in_profile_run = False
|
||||
|
||||
# kv role
|
||||
self.is_kv_producer = False
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _get_forward_metadata_across_dp(
|
||||
self, total_num_scheduled_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, bool]:
|
||||
forward_metadata = torch.tensor(
|
||||
[total_num_scheduled_tokens, with_prefill],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata,
|
||||
op=ReduceOp.MAX,
|
||||
group=get_dp_group().cpu_group)
|
||||
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
|
||||
self,
|
||||
maybe_padded_num_tokens: int,
|
||||
num_tokens: int,
|
||||
with_prefill: bool,
|
||||
enable_dbo: bool = False,
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
if self.dp_size == 1:
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp = [0] * self.dp_size * 2
|
||||
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
|
||||
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
|
||||
forward_metadata = torch.tensor(num_tokens_across_dp +
|
||||
[with_prefill, not enable_dbo],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
|
||||
with_prefill = bool(forward_metadata[-2])
|
||||
|
||||
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
|
||||
if with_prefill:
|
||||
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
|
||||
2]
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
else:
|
||||
num_tokens_across_dp = forward_metadata[:self.dp_size]
|
||||
|
||||
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
|
||||
# `max_tokens_across_dp`, in other situation it is not necessary.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
|
||||
self.dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||
forward_metadata[-1])
|
||||
|
||||
def get_eagle_atten_dict(
|
||||
self,
|
||||
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
if self.dp_size > 1:
|
||||
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
||||
total_num_scheduled_tokens, with_prefill)
|
||||
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
|
||||
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
||||
|
||||
# Add graph_pad_size here
|
||||
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
if self.dp_size > 1:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
max_num_tokens)
|
||||
else:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
total_num_scheduled_tokens)
|
||||
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
total_num_scheduled_tokens)
|
||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
input_ids = self.input_ids[:padded_batch_size]
|
||||
positions = self.positions[:padded_batch_size]
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
})
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=padded_num_tokens_across_dp,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
num_actual_tokens=total_num_scheduled_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
model_kwargs = {}
|
||||
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_batch_size)
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def kv_connector_no_forward(
|
||||
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
finished_sending, finished_recving = (
|
||||
self.get_finished_kv_transfer(scheduler_output))
|
||||
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
is_compile: bool = False,
|
||||
with_prefill: bool = True,
|
||||
skip_attn: bool = True,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
) -> torch.Tensor:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, num_tokens, with_prefill, False)
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
# has num_tokens in total.
|
||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer:
|
||||
with_prefill = True
|
||||
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
elif skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||
attn_metadata = None
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
with set_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
in_profile_run=self.in_profile_run,
|
||||
num_actual_tokens=0,
|
||||
):
|
||||
model_kwargs = {}
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
# Only mark static while compiling
|
||||
if is_compile:
|
||||
if is_torchair_compile:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(
|
||||
get_forward_context().mc2_mask)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(
|
||||
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_tokens)
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
return hidden_states
|
||||
|
||||
@contextmanager
|
||||
def set_in_profile_run(self):
|
||||
self.in_profile_run = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.in_profile_run = False
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
with self.set_in_profile_run():
|
||||
hidden_states = self._dummy_run(self.max_num_tokens,
|
||||
with_prefill=True)
|
||||
output = None
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.is_pooling_model:
|
||||
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens,
|
||||
is_compile=True,
|
||||
with_prefill=False)
|
||||
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||
|
||||
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Trigger ACL graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
|
||||
with graph_capture(device=self.device):
|
||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
|
||||
Reference in New Issue
Block a user