[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -7,13 +7,13 @@ from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
@@ -142,9 +142,9 @@ class EagleProposer:
|
||||
self.positions[:num_tokens] = target_positions.to(device)
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
attn_metadata.block_tables = block_table.to(device)
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
@@ -239,9 +239,9 @@ class EagleProposer:
|
||||
attn_metadata.attn_mask = attn_mask
|
||||
attn_metadata.block_tables = block_table.to(device)
|
||||
# Run the model.
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:input_batch_size],
|
||||
@@ -344,8 +344,9 @@ class EagleProposer:
|
||||
self,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
with set_forward_context(None, self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
with set_ascend_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
self.model(
|
||||
input_ids=self.input_ids[:num_tokens],
|
||||
positions=self.positions[:num_tokens],
|
||||
|
||||
@@ -34,7 +34,6 @@ import torch
|
||||
import torch._dynamo.cache_size
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ReduceOp
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||
self.in_profile_run = False
|
||||
|
||||
# kv role
|
||||
self.is_kv_producer = False
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _get_forward_metadata_across_dp(
|
||||
self, total_num_scheduled_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, bool]:
|
||||
forward_metadata = torch.tensor(
|
||||
[total_num_scheduled_tokens, with_prefill],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata,
|
||||
op=ReduceOp.MAX,
|
||||
group=get_dp_group().cpu_group)
|
||||
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
|
||||
self,
|
||||
maybe_padded_num_tokens: int,
|
||||
num_tokens: int,
|
||||
with_prefill: bool,
|
||||
enable_dbo: bool = False,
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
if self.dp_size == 1:
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp = [0] * self.dp_size * 2
|
||||
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
|
||||
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
|
||||
forward_metadata = torch.tensor(num_tokens_across_dp +
|
||||
[with_prefill, not enable_dbo],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
|
||||
with_prefill = bool(forward_metadata[-2])
|
||||
|
||||
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
|
||||
if with_prefill:
|
||||
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
|
||||
2]
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
else:
|
||||
num_tokens_across_dp = forward_metadata[:self.dp_size]
|
||||
|
||||
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
|
||||
# `max_tokens_across_dp`, in other situation it is not necessary.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
|
||||
self.dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||
forward_metadata[-1])
|
||||
|
||||
def get_eagle_atten_dict(
|
||||
self,
|
||||
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
if self.dp_size > 1:
|
||||
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
||||
total_num_scheduled_tokens, with_prefill)
|
||||
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
|
||||
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
||||
|
||||
# Add graph_pad_size here
|
||||
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
if self.dp_size > 1:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
max_num_tokens)
|
||||
else:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
total_num_scheduled_tokens)
|
||||
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
total_num_scheduled_tokens)
|
||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
input_ids = self.input_ids[:padded_batch_size]
|
||||
positions = self.positions[:padded_batch_size]
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
})
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=padded_num_tokens_across_dp,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
num_actual_tokens=total_num_scheduled_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
model_kwargs = {}
|
||||
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_batch_size)
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def kv_connector_no_forward(
|
||||
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
with set_ascend_forward_context(None, self.vllm_config):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
finished_sending, finished_recving = (
|
||||
self.get_finished_kv_transfer(scheduler_output))
|
||||
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
is_compile: bool = False,
|
||||
with_prefill: bool = True,
|
||||
skip_attn: bool = True,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
) -> torch.Tensor:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, num_tokens, with_prefill, False)
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
# has num_tokens in total.
|
||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer:
|
||||
with_prefill = True
|
||||
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
elif skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
|
||||
attn_metadata = None
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
with set_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
in_profile_run=self.in_profile_run,
|
||||
num_actual_tokens=0,
|
||||
):
|
||||
model_kwargs = {}
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
attn_metadata = self.attn_metadata_builder.build_dummy(
|
||||
num_reqs=num_tokens, num_actual_tokens=1)
|
||||
# Only mark static while compiling
|
||||
if is_compile:
|
||||
if is_torchair_compile:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(
|
||||
get_forward_context().mc2_mask)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(
|
||||
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_tokens)
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
return hidden_states
|
||||
|
||||
@contextmanager
|
||||
def set_in_profile_run(self):
|
||||
self.in_profile_run = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.in_profile_run = False
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
with self.set_in_profile_run():
|
||||
hidden_states = self._dummy_run(self.max_num_tokens,
|
||||
with_prefill=True)
|
||||
output = None
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.is_pooling_model:
|
||||
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens,
|
||||
is_compile=True,
|
||||
with_prefill=False)
|
||||
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||
|
||||
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Trigger ACL graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
|
||||
with graph_capture(device=self.device):
|
||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
|
||||
@@ -2,12 +2,12 @@ import torch
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ class MtpProposer:
|
||||
query_start_loc=cu_num_tokens,
|
||||
)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
with set_ascend_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=target_positions,
|
||||
|
||||
@@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
|
||||
vllm_version_is)
|
||||
from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
|
||||
try_register_lib, vllm_version_is)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if not vllm_version_is("0.10.0"):
|
||||
@@ -134,6 +135,7 @@ class NPUWorker(WorkerBase):
|
||||
NPUPlatform.empty_cache()
|
||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||
|
||||
init_ascend_soc_version()
|
||||
# Initialize the distributed environment.
|
||||
self._init_worker_distributed_environment()
|
||||
# Set random seed.
|
||||
@@ -272,20 +274,8 @@ class NPUWorker(WorkerBase):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def _get_max_num_tokens_and_with_prefill(self):
|
||||
max_num_tokens = 1
|
||||
with_prefill = False
|
||||
if self.model_runner.dp_size > 1:
|
||||
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
|
||||
max_num_tokens, with_prefill)
|
||||
return max_num_tokens, with_prefill
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
|
||||
)
|
||||
self.model_runner._dummy_run(max_num_tokens,
|
||||
is_compile=False,
|
||||
with_prefill=with_prefill)
|
||||
self.model_runner._dummy_run(1)
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
@@ -295,6 +285,7 @@ class NPUWorker(WorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
init_ascend_model_parallel(self.parallel_config)
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
def _init_profiler(self):
|
||||
|
||||
Reference in New Issue
Block a user