Fix some ci issue and refactor modelrunner (#2445)

### What this PR does / why we need it?
Fix some ci issue and refactor modelrunner

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with existing test.

- vLLM version: v0.10.0
- vLLM main:
4d9c61993a

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
Mengqing Cao
2025-08-20 09:01:04 +08:00
committed by GitHub
parent 955411611c
commit 1327f9be1c
28 changed files with 1612 additions and 1020 deletions

View File

@@ -23,7 +23,6 @@ import math
import os
import time
import types
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
@@ -34,16 +33,21 @@ import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import DPMetadata, get_forward_context
get_tp_group,
is_global_first_rank)
from vllm.forward_context import (BatchDescriptor, DPMetadata,
get_forward_context)
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -55,15 +59,17 @@ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, SupportedTask
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv)
LazyLoader, cdiv, is_pin_memory_available)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -79,6 +85,8 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MoECommMethod)
@@ -154,8 +162,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.load_config = vllm_config.load_config
self.lora_config = vllm_config.lora_config
self.parallel_config = vllm_config.parallel_config
self.pin_memory = is_pin_memory_available()
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.block_size = vllm_config.cache_config.block_size
@@ -215,7 +226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
use_mla=self.model_config.use_mla,
)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
vllm_config, device)
self.attn_mask_builder = AttentionMaskBuilder(
min(self.model_config.max_model_len,
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
@@ -228,13 +239,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
self.actual_seq_lengths_q = []
self.spec_token_num = 0
self.decode_token_per_req = 1
if self.speculative_config:
self.use_spec_decode = True
self.spec_token_num = self.speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.decode_token_per_req = 1 + self.spec_token_num
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.actual_seq_lengths_q = [
len for len in
range(self.decode_token_per_req, self.max_num_tokens +
@@ -331,13 +341,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.use_aclgraph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager and
not ascend_config.torchair_graph_config.enabled)
self.use_aclgraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.level == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager
and not ascend_config.torchair_graph_config.enabled)
self.aclgraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
reversed(self.compilation_config.cudagraph_capture_sizes))
self.uniform_decode_query_len = 1 if not self.speculative_config else \
1 + self.speculative_config.num_speculative_tokens
# aclgraph dispatcher for runtime aclgraph dispatching.
self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config)
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
@@ -405,12 +423,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
The SamplingMetadata is updated and copied to the NPU if there is a
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
@@ -421,11 +433,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: List[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
self.input_batch.remove_request(req_id)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -448,16 +457,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
req_index = self.input_batch.remove_request(req_id)
assert req_index is not None
removed_req_indices.append(req_index)
self.input_batch.remove_request(req_id)
req_ids_to_add: List[str] = []
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
if sampling_params and \
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
@@ -468,7 +476,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if pooling_params:
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model)
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
@@ -478,7 +486,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=new_req_data.pooling_params,
pooling_params=pooling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
@@ -493,9 +501,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
for mm_item in self.requests[req_id].mm_kwargs:
mm_input = mm_item.get_data()
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.append(
mm_input["image_grid_thw"].tolist())
@@ -528,19 +535,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
req_data = scheduler_output.scheduled_cached_reqs
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
@@ -549,11 +561,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip( # type: ignore[call-overload]
req_state.block_ids, new_block_ids):
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
@@ -571,9 +584,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
@@ -583,9 +597,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index]
@@ -595,39 +611,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices.sort(reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
req_index = self.input_batch.num_reqs - 1
start_index = len(req_state.prompt_token_ids) + len(
req_state.output_token_ids)
end_token_index = start_index + len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
self.input_batch.num_tokens[req_index] = end_token_index
self.input_batch.add_request(req_state)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
if batch_changed:
self.input_batch.refresh_sampling_metadata()
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _get_forward_metadata_across_dp(
self, num_tokens: int, with_prefill: bool,
@@ -798,17 +792,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
attn_metadata_i = self.attn_metadata_builder.build(
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=self.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.slot_mapping_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata_i = self.attn_metadata_builder.build(
common_attn_metadata, self.get_model())
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def get_model(self) -> nn.Module:
# get raw model out of the aclgraph wrapper.
if isinstance(self.model, ACLGraphWrapper):
return self.model.unwrap()
return self.model
def get_supported_generation_tasks(self) -> "list[GenerationTask]":
@@ -1063,11 +1074,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens)
num_input_tokens += num_pad
modified_batch = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
if modified_batch:
self.input_batch.refresh_sampling_metadata()
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
@@ -1168,8 +1176,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
extra_builder_kwargs = {}
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.query_start_loc[:num_reqs + 1].copy_(
@@ -1186,45 +1192,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
]
is_only_prefill = bool(np.all(num_valid_tokens != 1))
extra_builder_kwargs['is_only_prefill'] = is_only_prefill
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
attn_state,
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
total_num_scheduled_tokens, with_prefill, enable_dbo)
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
if self.torchair_graph_enabled and not with_prefill:
self.graph_pad_size = padded_num_tokens_across_dp
extra_builder_kwargs[
'graph_pad_size'] = self.graph_pad_size # type: ignore
else:
self.graph_pad_size = -1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=self.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.slot_mapping_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo,
is_only_prefill=is_only_prefill,
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata = self.attn_metadata_builder.build(
common_attn_metadata, self.model)
if self.vllm_config.model_config.use_mla:
extra_builder_kwargs[
"query_start_loc"] = self.query_start_loc[:num_reqs + 1]
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
else:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
**extra_builder_kwargs,
)
# Prepare input_ids
token_indices = (positions_np +
@@ -1534,7 +1539,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
return logits.to(self.device).to(logits_dtype)
def _get_spec_token_ids(
def propose_draft_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
@@ -1549,23 +1554,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
draft_token_ids = None
elif self.speculative_config.method == "ngram":
spec_token_ids = self._generate_ngram_token_ids(
draft_token_ids = self._generate_ngram_token_ids(
valid_sampled_token_ids)
elif self.speculative_config.method == "eagle":
raise NotImplementedError("Eagle Is Not Supported Yet.")
elif self.speculative_config.method == "eagle3":
spec_token_ids = self._generate_eagle3_token_ids(
draft_token_ids = self._generate_eagle3_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, aux_hidden_states)
elif self.speculative_config.method == 'deepseek_mtp':
spec_token_ids = self._generate_mtp_token_ids(
draft_token_ids = self._generate_mtp_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, attn_metadata)
return spec_token_ids
return draft_token_ids
def _pool(
self,
@@ -1606,7 +1611,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
@@ -1785,17 +1789,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
spec_token_ids = self._get_spec_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
)
if self.speculative_config:
self._draft_token_ids = self.propose_draft_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
aux_hidden_states,
)
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
@@ -1806,7 +1811,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
@@ -1825,6 +1829,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.input_batch.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_ascend_forward_context(None, self.vllm_config):
@@ -1898,30 +1913,66 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
) -> torch.Tensor:
# only support eager mode and piecewise graph now
assert aclgraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
}
if force_attention:
raise RuntimeError(
"Capturing attention in aclgraph is unexpected, because full graph is not supported now"
)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
# uniform decode batches. A uniform decode batch means that all
# requests have identical query length, except a potential virtual
# request (shorter) in the batch account for padding.
# Uniform decode batch could either be common pure decode, where
# max_query_len == 1, or speculative decode, where
# max_query_len == 1 + num_spec_decode_tokens.
# When setting max_query_len = 1, we switch to and capture the optimized
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
if with_prefill:
num_reqs = num_tokens
if uniform_decode:
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
else:
num_reqs = (num_tokens + self.decode_token_per_req -
1) // self.decode_token_per_req
num_reqs = min(num_reqs, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
if with_prefill:
num_reqs = num_tokens
else:
num_reqs = (num_tokens + self.decode_token_per_req -
1) // self.decode_token_per_req
num_reqs = min(num_reqs, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
@@ -1931,8 +1982,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
skip_attn)
attn_metadata = self._build_attention_metadata(with_prefill,
num_reqs,
skip_attn=True)
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@@ -1961,6 +2013,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
if aclgraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None
else:
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = \
self.aclgraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode))
# sanity check
assert aclgraph_runtime_mode == _cg_mode, (
f"Aclgraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.")
with set_ascend_forward_context(
attn_metadata,
@@ -1973,7 +2037,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
moe_comm_method=moe_comm_method(
self.device, self.dtype, self.model_config.hf_config),
num_actual_tokens=0,
):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
@@ -1983,7 +2048,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(
num_tokens=num_tokens,
with_prefill=with_prefill,
skip_attn=skip_attn,
skip_attn=True,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp)
@@ -2026,53 +2091,71 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.encoder_cache.clear()
gc.collect()
@torch.inference_mode()
def _dummy_pooler_run(
def _dummy_pooler_run_task(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
task: PoolingTask,
) -> PoolerOutput:
num_tokens = hidden_states.shape[0]
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
req_num_tokens = num_tokens // num_reqs
model = cast(VllmModelForPooling, self.model)
dummy_task = self.get_supported_pooling_tasks()[0]
dummy_pooling_params = PoolingParams(task=dummy_task)
dummy_prompt_lens = torch.tensor(
[h.shape[0] for h in hidden_states_list],
device=self.device,
)
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device)
to_update = model.pooler.get_pooling_updates(dummy_task)
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
dummy_metadata = PoolingMetadata(
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
device=self.device),
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device),
pooling_params=[dummy_pooling_params] * num_reqs)
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
)
try:
pooler_output = model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
return model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
"NPU out of memory occurred when warming up pooler with "
f"{num_reqs} dummy requests. Please try lowering "
"`max_num_seqs` or `gpu_memory_utilization` when "
"NPU out of memory occurred when warming up pooler "
f"({task=}) with {num_reqs} dummy requests. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine.") from e
else:
raise e
return pooler_output
@torch.inference_mode()
def _dummy_pooler_run(
self,
hidden_states: torch.Tensor,
) -> PoolerOutput:
# Find the task that has the largest output for subsequent steps
output_size = dict[PoolingTask, float]()
for task in self.get_supported_pooling_tasks():
# Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = output.get_data_nbytes()
del output # Allow GC
max_task = max(output_size.items(), key=lambda x: x[1])[0]
return self._dummy_pooler_run_task(hidden_states, max_task)
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
@@ -2199,10 +2282,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=True,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
)
kv_cache_sizes = {}
@@ -2315,10 +2403,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
@@ -2332,7 +2419,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
forward_ctx = self.compilation_config.static_forward_context
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
@@ -2361,30 +2448,82 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def initialize_aclgraph_capture(self) -> None:
# TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported
# Trigger aclgraph dispatching keys initialization here (after
# initializing attn backends).
self.aclgraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
def _capture_aclgraphs(self, compilation_cases: list[int],
aclgraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing ACL graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
aclgraph_runtime_mode.name))
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
moe_comm_method=self.moe_comm_method)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=uniform_decode,
moe_comm_method=self.moe_comm_method)
def _capture_model(self):
if not self.use_aclgraph:
logger.info("Skipping NPU graph capture for eager mode.")
logger.warning(
"Skipping ACL graph capture. To turn on ACL graph capture, "
"ensure `aclraph_mode` was not manually set to `NONE`")
return
else:
self.initialize_aclgraph_capture()
set_cudagraph_capturing_enabled(True)
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
aclgraph_mode = self.compilation_config.cudagraph_mode
if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
compilation_cases = list(reversed(self.aclgraph_batch_sizes))
self._capture_aclgraphs(
compilation_cases,
aclgraph_runtime_mode=aclgraph_runtime_mode,
uniform_decode=False)
# Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because
# we may doing lazy capturing in future that still allows capturing
# after here.
set_cudagraph_capturing_enabled(False)
def capture_model(self) -> None:
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]