Fix some CI issues and refactor modelrunner (#2445)

### What this PR does / why we need it?
Fix some CI issues and refactor the model runner.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the existing tests.

- vLLM version: v0.10.0
- vLLM main:
4d9c61993a

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
Mengqing Cao
2025-08-20 09:01:04 +08:00
committed by GitHub
parent 955411611c
commit 1327f9be1c
28 changed files with 1612 additions and 1020 deletions

View File

@@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
@@ -17,11 +18,14 @@ from vllm.utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
npu_stream_switch, npu_wait_tensor)
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -172,20 +176,24 @@ class AscendMLAMetadataBuilder:
# _attn_mask_builder = None
def __init__(self,
runner,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
if metadata_cls is not None else AscendMLAMetadata # type: ignore
self.runner = runner
scheduler_config = runner.scheduler_config
model_config = runner.model_config
self.block_size = runner.block_size
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Make sure there is enough for 8 full-length requests or at least
# 4 pages of cache per request
max(8 * model_config.max_model_len,
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
@@ -200,13 +208,13 @@ class AscendMLAMetadataBuilder:
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
@@ -220,8 +228,6 @@ class AscendMLAMetadataBuilder:
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
@@ -231,18 +237,14 @@ class AscendMLAMetadataBuilder:
if self.torchair_graph_enabled:
if num_tokens - num_spec_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# For eager mode we treat spec decoding as chunked prefill.
else:
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
@@ -273,26 +275,15 @@ class AscendMLAMetadataBuilder:
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _get_graph_runner_block_tables(
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
max_blocks = self.max_blocks
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
if isinstance(self.runner.graph_block_tables, np.ndarray):
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
else:
graph_block_tables = self.runner.graph_block_tables.to(
device=block_tables.device, dtype=block_tables.dtype)
graph_block_tables = torch.zeros((num_seqs, max_blocks),
dtype=block_tables.dtype,
device=block_tables.device)
num_blocks = block_tables.size(1)
if num_blocks <= max_blocks:
@@ -304,18 +295,20 @@ class AscendMLAMetadataBuilder:
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables[:num_seqs, :max_blocks]
return graph_block_tables[:, :max_blocks]
def build_torchair_graph_dummy(
self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
self,
common_attn_metadata: TorchairCommonAttentionMetadata,
) -> AscendMLAMetadata:
device = self.device
num_reqs = common_attn_metadata.num_reqs
block_table = torch.zeros((num_reqs, self.max_blocks),
dtype=torch.int32,
device=device)
block_table = self._get_graph_runner_block_tables(
num_reqs, block_table)
num_tokens = num_reqs * self.runner.decode_token_per_req
num_tokens = num_reqs * common_attn_metadata.decode_token_per_req
seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
seq_lens_list = [0] * num_reqs
input_positions = torch.zeros(num_tokens,
@@ -333,16 +326,16 @@ class AscendMLAMetadataBuilder:
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
dtype=self.model_config.dtype,
device=device)
cos = torch.ones(num_tokens,
1,
1,
self.rope_dim,
dtype=self.runner.dtype,
dtype=self.model_config.dtype,
device=device)
if self.runner.speculative_config is not None and\
self.runner.speculative_config.method == 'deepseek_mtp':
if self.vllm_config.speculative_config is not None and\
self.vllm_config.speculative_config.method == 'deepseek_mtp':
attn_state = AscendAttentionState.SpecDecoding
num_decode_tokens = 2
else:
@@ -354,20 +347,21 @@ class AscendMLAMetadataBuilder:
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=1,
attn_mask=self.runner.spec_attn_mask,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs],
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=common_attn_metadata.
actual_seq_lengths_q[:num_reqs],
sin=sin,
cos=cos,
)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
num_input_tokens=common_attn_metadata.num_actual_tokens,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
head_dim=self.model_config.get_head_size(),
num_decodes=1,
num_decode_tokens=num_decode_tokens,
num_prefills=0,
attn_mask=self.runner.attn_mask,
attn_mask=common_attn_metadata.attn_mask,
attn_state=attn_state,
prefill=None,
decode=decode_metadata,
@@ -378,58 +372,68 @@ class AscendMLAMetadataBuilder:
def build(
self,
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
graph_pad_size: int = -1,
query_start_loc: torch.Tensor = None,
enable_dbo_across_dp: bool = False,
*args,
**kwargs,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
if self.torchair_graph_enabled and common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
decode_threshold = common_attn_metadata.decode_token_per_req
else:
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
decode_threshold = 1
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
device = self.device
block_table = (self.runner.input_batch.block_table[0].
get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True)
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
num_reqs]
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
if self.cos_cache is None:
self.cos_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.cos_cached
self.sin_cache = self.runner.get_model(
).model.layers[0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.runner.dtype: # type: ignore
self.cos_cache = model.model.layers[
0].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.runner.dtype) # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.runner.dtype) # type: ignore
self.model_config.dtype) # type: ignore
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_lens = query_seq_lens_cpu[:num_reqs]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (seq_lens - query_lens)
prefill_metadata = None
chunked_context_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
reqs_start:num_reqs]
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
@@ -441,12 +445,12 @@ class AscendMLAMetadataBuilder:
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
@@ -470,7 +474,7 @@ class AscendMLAMetadataBuilder:
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens[tokens_start:],
seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:],
@@ -485,14 +489,15 @@ class AscendMLAMetadataBuilder:
)
decode_metadata = None
graph_pad_size = common_attn_metadata.graph_pad_size
use_torchair_graph = graph_pad_size != -1
if self._num_decodes > 0:
if num_decodes > 0:
actual_seq_lengths_q = query_start_loc[1:].tolist()
max_seq_lens = seq_lens[:self._num_decodes].max().item()
seq_lens = seq_lens[:self._num_decode_tokens]
input_positions = input_positions[:self._num_decode_tokens]
block_table = block_table[:self._num_decode_tokens, ...]
if use_torchair_graph and self.runner.attn_state in [
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decode_tokens]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decode_tokens, ...]
if use_torchair_graph and common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
@@ -500,10 +505,10 @@ class AscendMLAMetadataBuilder:
num_token_pad_size = 0
if graph_pad_size != 0:
pad_value = 0
num_token_pad_size = graph_pad_size - self._num_decode_tokens
num_token_pad_size = graph_pad_size - num_decode_tokens
num_reqs_pad_size = (
graph_pad_size // self.runner.decode_token_per_req -
num_reqs)
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
padded_seq_lens = seq_lens.tolist(
) + [pad_value] * num_reqs_pad_size
else:
@@ -531,14 +536,14 @@ class AscendMLAMetadataBuilder:
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = query_start_loc[1:].tolist(
) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
num_reqs_pad_size]
) + common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size]
else:
seq_lens_list = seq_lens.tolist()
# mtp torchair + PD scenario: the last element of actual_seq_lengths_q must be equal to batch_size (num_tokens)
batch_size = slot_mapping.size(0)
if actual_seq_lengths_q[-1] != batch_size \
and self.runner.attn_state == AscendAttentionState.SpecDecoding:
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
actual_seq_lengths_q[-1] = batch_size
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
@@ -552,7 +557,7 @@ class AscendMLAMetadataBuilder:
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=self.runner.spec_attn_mask,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos)
@@ -561,18 +566,18 @@ class AscendMLAMetadataBuilder:
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
attn_mask=self.runner.attn_mask,
attn_state=self.runner.attn_state,
head_dim=self.model_config.get_head_size(),
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=enable_dbo_across_dp,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)